[UDP]: Fix AF-specific references in AF-agnostic code.
authorDavid S. Miller <davem@sunset.davemloft.net>
Wed, 9 May 2007 23:42:20 +0000 (16:42 -0700)
committerDavid S. Miller <davem@sunset.davemloft.net>
Fri, 11 May 2007 06:47:22 +0000 (23:47 -0700)
__udp_lib_port_inuse() cannot make direct references to
inet_sk(sk)->rcv_saddr as that is ipv4 specific state and
this code is used by ipv6 too.

Use an operations vector to solve this, and this also paves
the way for ipv6 support for non-wild saddr hashing in UDP.

Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/udp.h
include/net/udplite.h
net/ipv4/udp.c
net/ipv4/udp_impl.h
net/ipv4/udplite.c
net/ipv6/udp.c
net/ipv6/udp_impl.h
net/ipv6/udplite.c

index 98755ebaf163cfe788b4117167cfeaf1a861f269..496f89d45c8b89eabf7b57bde71016e85552a677 100644 (file)
@@ -119,9 +119,16 @@ static inline void udp_lib_close(struct sock *sk, long timeout)
 }
 
 
+struct udp_get_port_ops {
+       int (*saddr_cmp)(const struct sock *sk1, const struct sock *sk2);
+       int (*saddr_any)(const struct sock *sk);
+       unsigned int (*hash_port_and_rcv_saddr)(__u16 port,
+                                               const struct sock *sk);
+};
+
 /* net/ipv4/udp.c */
 extern int     udp_get_port(struct sock *sk, unsigned short snum,
-                            int (*saddr_cmp)(const struct sock *, const struct sock *));
+                            const struct udp_get_port_ops *ops);
 extern void    udp_err(struct sk_buff *, u32);
 
 extern int     udp_sendmsg(struct kiocb *iocb, struct sock *sk,
index 635b0eafca95d256e789edb0323cbb8aef759f4f..50b4b424d1caab47e79673eb3fa6972c45651e91 100644 (file)
@@ -120,5 +120,5 @@ static inline __wsum udplite_csum_outgoing(struct sock *sk, struct sk_buff *skb)
 
 extern void    udplite4_register(void);
 extern int     udplite_get_port(struct sock *sk, unsigned short snum,
-                       int (*scmp)(const struct sock *, const struct sock *));
+                                const struct udp_get_port_ops *ops);
 #endif /* _UDPLITE_H */
index 66026df1cc7639bcba7f9b0232b8a4267107e7b4..4c7e95fa090d181234e3dbb1a2d1934a259c317f 100644 (file)
@@ -118,15 +118,15 @@ static int udp_port_rover;
  * Note about this hash function :
  * Typical use is probably daddr = 0, only dport is going to vary hash
  */
-static inline unsigned int hash_port_and_addr(__u16 port, __be32 addr)
+static inline unsigned int udp_hash_port(__u16 port)
 {
-       addr ^= addr >> 16;
-       addr ^= addr >> 8;
-       return port ^ addr;
+       return port;
 }
 
 static inline int __udp_lib_port_inuse(unsigned int hash, int port,
-       __be32 daddr, struct hlist_head udptable[])
+                                      const struct sock *this_sk,
+                                      struct hlist_head udptable[],
+                                      const struct udp_get_port_ops *ops)
 {
        struct sock *sk;
        struct hlist_node *node;
@@ -138,7 +138,10 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
                inet = inet_sk(sk);
                if (inet->num != port)
                        continue;
-               if (inet->rcv_saddr == daddr)
+               if (this_sk) {
+                       if (ops->saddr_cmp(sk, this_sk))
+                               return 1;
+               } else if (ops->saddr_any(sk))
                        return 1;
        }
        return 0;
@@ -151,12 +154,11 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
  *  @snum:        port number to look up
  *  @udptable:    hash list table, must be of UDP_HTABLE_SIZE
  *  @port_rover:  pointer to record of last unallocated port
- *  @saddr_comp:  AF-dependent comparison of bound local IP addresses
+ *  @ops:         AF-dependent address operations
  */
 int __udp_lib_get_port(struct sock *sk, unsigned short snum,
                       struct hlist_head udptable[], int *port_rover,
-                      int (*saddr_comp)(const struct sock *sk1,
-                                        const struct sock *sk2 )    )
+                      const struct udp_get_port_ops *ops)
 {
        struct hlist_node *node;
        struct hlist_head *head;
@@ -176,8 +178,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
                for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) {
                        int size;
 
-                       hash = hash_port_and_addr(result,
-                                       inet_sk(sk)->rcv_saddr);
+                       hash = ops->hash_port_and_rcv_saddr(result, sk);
                        head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
                        if (hlist_empty(head)) {
                                if (result > sysctl_local_port_range[1])
@@ -203,17 +204,16 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
                                result = sysctl_local_port_range[0]
                                        + ((result - sysctl_local_port_range[0]) &
                                           (UDP_HTABLE_SIZE - 1));
-                       hash = hash_port_and_addr(result, 0);
+                       hash = udp_hash_port(result);
                        if (__udp_lib_port_inuse(hash, result,
-                                                0, udptable))
+                                                NULL, udptable, ops))
                                continue;
-                       if (!inet_sk(sk)->rcv_saddr)
+                       if (ops->saddr_any(sk))
                                break;
 
-                       hash = hash_port_and_addr(result,
-                                       inet_sk(sk)->rcv_saddr);
+                       hash = ops->hash_port_and_rcv_saddr(result, sk);
                        if (! __udp_lib_port_inuse(hash, result,
-                               inet_sk(sk)->rcv_saddr, udptable))
+                                                  sk, udptable, ops))
                                break;
                }
                if (i >= (1 << 16) / UDP_HTABLE_SIZE)
@@ -221,7 +221,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
 gotit:
                *port_rover = snum = result;
        } else {
-               hash = hash_port_and_addr(snum, 0);
+               hash = udp_hash_port(snum);
                head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
 
                sk_for_each(sk2, node, head)
@@ -231,12 +231,11 @@ gotit:
                            (!sk2->sk_reuse || !sk->sk_reuse) &&
                            (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
                             sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
-                           (*saddr_comp)(sk, sk2))
+                           ops->saddr_cmp(sk, sk2))
                                goto fail;
 
-               if (inet_sk(sk)->rcv_saddr) {
-                       hash = hash_port_and_addr(snum,
-                                                 inet_sk(sk)->rcv_saddr);
+               if (!ops->saddr_any(sk)) {
+                       hash = ops->hash_port_and_rcv_saddr(snum, sk);
                        head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
 
                        sk_for_each(sk2, node, head)
@@ -248,7 +247,7 @@ gotit:
                                     !sk->sk_bound_dev_if ||
                                     sk2->sk_bound_dev_if ==
                                     sk->sk_bound_dev_if) &&
-                                   (*saddr_comp)(sk, sk2))
+                                   ops->saddr_cmp(sk, sk2))
                                        goto fail;
                }
        }
@@ -266,12 +265,12 @@ fail:
 }
 
 int udp_get_port(struct sock *sk, unsigned short snum,
-                       int (*scmp)(const struct sock *, const struct sock *))
+                const struct udp_get_port_ops *ops)
 {
-       return  __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp);
+       return  __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, ops);
 }
 
-int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
+static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
 {
        struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
 
@@ -280,9 +279,33 @@ int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
                   inet1->rcv_saddr == inet2->rcv_saddr      ));
 }
 
+static int ipv4_rcv_saddr_any(const struct sock *sk)
+{
+       return !inet_sk(sk)->rcv_saddr;
+}
+
+static inline unsigned int ipv4_hash_port_and_addr(__u16 port, __be32 addr)
+{
+       addr ^= addr >> 16;
+       addr ^= addr >> 8;
+       return port ^ addr;
+}
+
+static unsigned int ipv4_hash_port_and_rcv_saddr(__u16 port,
+                                                const struct sock *sk)
+{
+       return ipv4_hash_port_and_addr(port, inet_sk(sk)->rcv_saddr);
+}
+
+const struct udp_get_port_ops udp_ipv4_ops = {
+       .saddr_cmp = ipv4_rcv_saddr_equal,
+       .saddr_any = ipv4_rcv_saddr_any,
+       .hash_port_and_rcv_saddr = ipv4_hash_port_and_rcv_saddr,
+};
+
 static inline int udp_v4_get_port(struct sock *sk, unsigned short snum)
 {
-       return udp_get_port(sk, snum, ipv4_rcv_saddr_equal);
+       return udp_get_port(sk, snum, &udp_ipv4_ops);
 }
 
 /* UDP is nearly always wildcards out the wazoo, it makes no sense to try
@@ -297,8 +320,8 @@ static struct sock *__udp4_lib_lookup(__be32 saddr, __be16 sport,
        unsigned int hash, hashwild;
        int score, best = -1, hport = ntohs(dport);
 
-       hash = hash_port_and_addr(hport, daddr);
-       hashwild = hash_port_and_addr(hport, 0);
+       hash = ipv4_hash_port_and_addr(hport, daddr);
+       hashwild = udp_hash_port(hport);
 
        read_lock(&udp_hash_lock);
 
@@ -1198,8 +1221,8 @@ static int __udp4_lib_mcast_deliver(struct sk_buff *skb,
        struct sock *sk, *skw, *sknext;
        int dif;
        int hport = ntohs(uh->dest);
-       unsigned int hash = hash_port_and_addr(hport, daddr);
-       unsigned int hashwild = hash_port_and_addr(hport, 0);
+       unsigned int hash = ipv4_hash_port_and_addr(hport, daddr);
+       unsigned int hashwild = udp_hash_port(hport);
 
        dif = skb->dev->ifindex;
 
index 820a477cfaa6e9899234b03ee8120c9ca158babd..06d94195e644848fa464f54456c4abb450496a25 100644 (file)
@@ -5,14 +5,14 @@
 #include <net/protocol.h>
 #include <net/inet_common.h>
 
+extern const struct udp_get_port_ops udp_ipv4_ops;
+
 extern int     __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int );
 extern void    __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []);
 
 extern int     __udp_lib_get_port(struct sock *sk, unsigned short snum,
                                   struct hlist_head udptable[], int *port_rover,
-                                  int (*)(const struct sock*,const struct sock*));
-extern int     ipv4_rcv_saddr_equal(const struct sock *, const struct sock *);
-
+                                  const struct udp_get_port_ops *ops);
 
 extern int     udp_setsockopt(struct sock *sk, int level, int optname,
                               char __user *optval, int optlen);
index f34fd686a8f15ac6e49b687742701c1037ab1417..3653b32dce2d79f6ec18eb2a484aaf325448578a 100644 (file)
@@ -19,14 +19,15 @@ struct hlist_head   udplite_hash[UDP_HTABLE_SIZE];
 static int             udplite_port_rover;
 
 int udplite_get_port(struct sock *sk, unsigned short p,
-                    int (*c)(const struct sock *, const struct sock *))
+                    const struct udp_get_port_ops *ops)
 {
-       return  __udp_lib_get_port(sk, p, udplite_hash, &udplite_port_rover, c);
+       return  __udp_lib_get_port(sk, p, udplite_hash,
+                                  &udplite_port_rover, ops);
 }
 
 static int udplite_v4_get_port(struct sock *sk, unsigned short snum)
 {
-       return udplite_get_port(sk, snum, ipv4_rcv_saddr_equal);
+       return udplite_get_port(sk, snum, &udp_ipv4_ops);
 }
 
 static int udplite_rcv(struct sk_buff *skb)
index b083c09e3d2d1b4757042e08684a10d7dc48d649..a7ae59c954d5a2d68f7d18c6bdb3addf1abcbdaf 100644 (file)
 
 DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly;
 
+static int ipv6_rcv_saddr_any(const struct sock *sk)
+{
+       struct ipv6_pinfo *np = inet6_sk(sk);
+
+       return ipv6_addr_any(&np->rcv_saddr);
+}
+
+static unsigned int ipv6_hash_port_and_rcv_saddr(__u16 port,
+                                                const struct sock *sk)
+{
+       return port;
+}
+
+const struct udp_get_port_ops udp_ipv6_ops = {
+       .saddr_cmp = ipv6_rcv_saddr_equal,
+       .saddr_any = ipv6_rcv_saddr_any,
+       .hash_port_and_rcv_saddr = ipv6_hash_port_and_rcv_saddr,
+};
+
 static inline int udp_v6_get_port(struct sock *sk, unsigned short snum)
 {
-       return udp_get_port(sk, snum, ipv6_rcv_saddr_equal);
+       return udp_get_port(sk, snum, &udp_ipv6_ops);
 }
 
 static struct sock *__udp6_lib_lookup(struct in6_addr *saddr, __be16 sport,
index 6e252f318f7c91c81f5cdf04d2822b476f26a8ae..36b0c11a28a312dcad2bfa6f82d4bbd057dd307e 100644 (file)
@@ -6,6 +6,8 @@
 #include <net/addrconf.h>
 #include <net/inet_common.h>
 
+extern const struct udp_get_port_ops udp_ipv6_ops;
+
 extern int     __udp6_lib_rcv(struct sk_buff **, struct hlist_head [], int );
 extern void    __udp6_lib_err(struct sk_buff *, struct inet6_skb_parm *,
                               int , int , int , __be32 , struct hlist_head []);
index f54016a55004d2d931471c06f14db922efa5f84d..c40a51362f89ed6861d64acccdf371712ca7c61b 100644 (file)
@@ -37,7 +37,7 @@ static struct inet6_protocol udplitev6_protocol = {
 
 static int udplite_v6_get_port(struct sock *sk, unsigned short snum)
 {
-       return udplite_get_port(sk, snum, ipv6_rcv_saddr_equal);
+       return udplite_get_port(sk, snum, &udp_ipv6_ops);
 }
 
 struct proto udplitev6_prot = {