Merge branch 'timers-fixes-for-linus' of git://git.kernel.org/pub/scm/linux/kernel...
[sfrench/cifs-2.6.git] / net / ipv4 / udp.c
index cf5ab0581ebac98ddf919f92083f7066aa0bf61f..c47c989cb1fb162b9fb71f8f27a580f4ba8088d0 100644 (file)
@@ -120,8 +120,11 @@ EXPORT_SYMBOL(sysctl_udp_wmem_min);
 atomic_t udp_memory_allocated;
 EXPORT_SYMBOL(udp_memory_allocated);
 
+#define PORTS_PER_CHAIN (65536 / UDP_HTABLE_SIZE)
+
 static int udp_lib_lport_inuse(struct net *net, __u16 num,
                               const struct udp_hslot *hslot,
+                              unsigned long *bitmap,
                               struct sock *sk,
                               int (*saddr_comp)(const struct sock *sk1,
                                                 const struct sock *sk2))
@@ -132,12 +135,17 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
        sk_nulls_for_each(sk2, node, &hslot->head)
                if (net_eq(sock_net(sk2), net)                  &&
                    sk2 != sk                                   &&
-                   sk2->sk_hash == num                         &&
+                   (bitmap || sk2->sk_hash == num)             &&
                    (!sk2->sk_reuse || !sk->sk_reuse)           &&
                    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if
                        || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
-                   (*saddr_comp)(sk, sk2))
-                       return 1;
+                   (*saddr_comp)(sk, sk2)) {
+                       if (bitmap)
+                               __set_bit(sk2->sk_hash / UDP_HTABLE_SIZE,
+                                         bitmap);
+                       else
+                               return 1;
+               }
        return 0;
 }
 
@@ -160,32 +168,47 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
        if (!snum) {
                int low, high, remaining;
                unsigned rand;
-               unsigned short first;
+               unsigned short first, last;
+               DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN);
 
                inet_get_local_port_range(&low, &high);
                remaining = (high - low) + 1;
 
                rand = net_random();
-               snum = first = rand % remaining + low;
-               rand |= 1;
-               for (;;) {
-                       hslot = &udptable->hash[udp_hashfn(net, snum)];
+               first = (((u64)rand * remaining) >> 32) + low;
+               /*
+                * force rand to be an odd multiple of UDP_HTABLE_SIZE
+                */
+               rand = (rand | 1) * UDP_HTABLE_SIZE;
+               for (last = first + UDP_HTABLE_SIZE; first != last; first++) {
+                       hslot = &udptable->hash[udp_hashfn(net, first)];
+                       bitmap_zero(bitmap, PORTS_PER_CHAIN);
                        spin_lock_bh(&hslot->lock);
-                       if (!udp_lib_lport_inuse(net, snum, hslot, sk, saddr_comp))
-                               break;
-                       spin_unlock_bh(&hslot->lock);
+                       udp_lib_lport_inuse(net, snum, hslot, bitmap, sk,
+                                           saddr_comp);
+
+                       snum = first;
+                       /*
+                        * Iterate on all possible values of snum for this hash.
+                        * Using steps of an odd multiple of UDP_HTABLE_SIZE
+                        * give us randomization and full range coverage.
+                        */
                        do {
-                               snum = snum + rand;
-                       } while (snum < low || snum > high);
-                       if (snum == first)
-                               goto fail;
+                               if (low <= snum && snum <= high &&
+                                   !test_bit(snum / UDP_HTABLE_SIZE, bitmap))
+                                       goto found;
+                               snum += rand;
+                       } while (snum != first);
+                       spin_unlock_bh(&hslot->lock);
                }
+               goto fail;
        } else {
                hslot = &udptable->hash[udp_hashfn(net, snum)];
                spin_lock_bh(&hslot->lock);
-               if (udp_lib_lport_inuse(net, snum, hslot, sk, saddr_comp))
+               if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk, saddr_comp))
                        goto fail_unlock;
        }
+found:
        inet_sk(sk)->num = snum;
        sk->sk_hash = snum;
        if (sk_unhashed(sk)) {
@@ -992,9 +1015,11 @@ static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
 
        if ((rc = sock_queue_rcv_skb(sk, skb)) < 0) {
                /* Note that an ENOMEM error is charged twice */
-               if (rc == -ENOMEM)
+               if (rc == -ENOMEM) {
                        UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_RCVBUFERRORS,
                                         is_udplite);
+                       atomic_inc(&sk->sk_drops);
+               }
                goto drop;
        }
 
@@ -1206,11 +1231,10 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
                   int proto)
 {
        struct sock *sk;
-       struct udphdr *uh = udp_hdr(skb);
+       struct udphdr *uh;
        unsigned short ulen;
        struct rtable *rt = (struct rtable*)skb->dst;
-       __be32 saddr = ip_hdr(skb)->saddr;
-       __be32 daddr = ip_hdr(skb)->daddr;
+       __be32 saddr, daddr;
        struct net *net = dev_net(skb->dev);
 
        /*
@@ -1219,6 +1243,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        if (!pskb_may_pull(skb, sizeof(struct udphdr)))
                goto drop;              /* No space for header. */
 
+       uh   = udp_hdr(skb);
        ulen = ntohs(uh->len);
        if (ulen > skb->len)
                goto short_packet;
@@ -1233,6 +1258,9 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        if (udp4_csum_init(skb, uh, proto))
                goto csum_error;
 
+       saddr = ip_hdr(skb)->saddr;
+       daddr = ip_hdr(skb)->daddr;
+
        if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST))
                return __udp4_lib_mcast_deliver(net, skb, uh,
                                saddr, daddr, udptable);