mm/memory_hotplug.c: tidy up node_states_clear_node()
[sfrench/cifs-2.6.git] / mm / mempolicy.c
index da858f794eb694a934afb58da49f749343ed8c08..cfd26d7e61a17f9c5fd260b85778058aa04b83e2 100644 (file)
@@ -797,16 +797,19 @@ static void get_policy_nodemask(struct mempolicy *p, nodemask_t *nodes)
        }
 }
 
-static int lookup_node(unsigned long addr)
+static int lookup_node(struct mm_struct *mm, unsigned long addr)
 {
        struct page *p;
        int err;
 
-       err = get_user_pages(addr & PAGE_MASK, 1, 0, &p, NULL);
+       int locked = 1;
+       err = get_user_pages_locked(addr & PAGE_MASK, 1, 0, &p, &locked);
        if (err >= 0) {
                err = page_to_nid(p);
                put_page(p);
        }
+       if (locked)
+               up_read(&mm->mmap_sem);
        return err;
 }
 
@@ -817,7 +820,7 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
        int err;
        struct mm_struct *mm = current->mm;
        struct vm_area_struct *vma = NULL;
-       struct mempolicy *pol = current->mempolicy;
+       struct mempolicy *pol = current->mempolicy, *pol_refcount = NULL;
 
        if (flags &
                ~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
@@ -857,7 +860,16 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 
        if (flags & MPOL_F_NODE) {
                if (flags & MPOL_F_ADDR) {
-                       err = lookup_node(addr);
+                       /*
+                        * Take a refcount on the mpol, lookup_node()
+                        * wil drop the mmap_sem, so after calling
+                        * lookup_node() only "pol" remains valid, "vma"
+                        * is stale.
+                        */
+                       pol_refcount = pol;
+                       vma = NULL;
+                       mpol_get(pol);
+                       err = lookup_node(mm, addr);
                        if (err < 0)
                                goto out;
                        *policy = err;
@@ -892,7 +904,9 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
  out:
        mpol_cond_put(pol);
        if (vma)
-               up_read(&current->mm->mmap_sem);
+               up_read(&mm->mmap_sem);
+       if (pol_refcount)
+               mpol_put(pol_refcount);
        return err;
 }
 
@@ -2697,12 +2711,11 @@ static const char * const policy_modes[] =
 int mpol_parse_str(char *str, struct mempolicy **mpol)
 {
        struct mempolicy *new = NULL;
-       unsigned short mode;
        unsigned short mode_flags;
        nodemask_t nodes;
        char *nodelist = strchr(str, ':');
        char *flags = strchr(str, '=');
-       int err = 1;
+       int err = 1, mode;
 
        if (nodelist) {
                /* NUL-terminate mode or flags string */
@@ -2717,12 +2730,8 @@ int mpol_parse_str(char *str, struct mempolicy **mpol)
        if (flags)
                *flags++ = '\0';        /* terminate mode string */
 
-       for (mode = 0; mode < MPOL_MAX; mode++) {
-               if (!strcmp(str, policy_modes[mode])) {
-                       break;
-               }
-       }
-       if (mode >= MPOL_MAX)
+       mode = match_string(policy_modes, MPOL_MAX, str);
+       if (mode < 0)
                goto out;
 
        switch (mode) {