inet: frags: better deal with smp races
[sfrench/cifs-2.6.git] / net / ipv4 / inet_fragment.c
index bcb11f3a27c0c34115af05034a5a20f57842eb0a..760a9e52e02b91b36af323c92f7027e150858f88 100644 (file)
@@ -178,21 +178,22 @@ static struct inet_frag_queue *inet_frag_alloc(struct netns_frags *nf,
 }
 
 static struct inet_frag_queue *inet_frag_create(struct netns_frags *nf,
-                                               void *arg)
+                                               void *arg,
+                                               struct inet_frag_queue **prev)
 {
        struct inet_frags *f = nf->f;
        struct inet_frag_queue *q;
-       int err;
 
        q = inet_frag_alloc(nf, f, arg);
-       if (!q)
+       if (!q) {
+               *prev = ERR_PTR(-ENOMEM);
                return NULL;
-
+       }
        mod_timer(&q->timer, jiffies + nf->timeout);
 
-       err = rhashtable_insert_fast(&nf->rhashtable, &q->node,
-                                    f->rhash_params);
-       if (err < 0) {
+       *prev = rhashtable_lookup_get_insert_key(&nf->rhashtable, &q->key,
+                                                &q->node, f->rhash_params);
+       if (*prev) {
                q->flags |= INET_FRAG_COMPLETE;
                inet_frag_kill(q);
                inet_frag_destroy(q);
@@ -204,22 +205,22 @@ static struct inet_frag_queue *inet_frag_create(struct netns_frags *nf,
 /* TODO : call from rcu_read_lock() and no longer use refcount_inc_not_zero() */
 struct inet_frag_queue *inet_frag_find(struct netns_frags *nf, void *key)
 {
-       struct inet_frag_queue *fq;
+       struct inet_frag_queue *fq = NULL, *prev;
 
        if (!nf->high_thresh || frag_mem_limit(nf) > nf->high_thresh)
                return NULL;
 
        rcu_read_lock();
 
-       fq = rhashtable_lookup(&nf->rhashtable, key, nf->f->rhash_params);
-       if (fq) {
+       prev = rhashtable_lookup(&nf->rhashtable, key, nf->f->rhash_params);
+       if (!prev)
+               fq = inet_frag_create(nf, key, &prev);
+       if (prev && !IS_ERR(prev)) {
+               fq = prev;
                if (!refcount_inc_not_zero(&fq->refcnt))
                        fq = NULL;
-               rcu_read_unlock();
-               return fq;
        }
        rcu_read_unlock();
-
-       return inet_frag_create(nf, key);
+       return fq;
 }
 EXPORT_SYMBOL(inet_frag_find);