cifs: avoid potential races when handling multiple dfs tcons
diff --git a/fs/cifs/dfs_cache.c b/fs/cifs/dfs_cache.c
index 6557d7b2798a034397e2dfb5f794e7ae14a52dd2..1513b2709889b7abc6333cd3406f97c8e5afc0e9 100644
--- a/fs/cifs/dfs_cache.c
+++ b/fs/cifs/dfs_cache.c
 #include "cifs_unicode.h"
 #include "smb2glob.h"
 #include "dns_resolve.h"
+#include "dfs.h"
 
 #include "dfs_cache.h"
 
-#define CACHE_HTABLE_SIZE 32
-#define CACHE_MAX_ENTRIES 64
-#define CACHE_MIN_TTL 120 /* 2 minutes */
+#define CACHE_HTABLE_SIZE      32
+#define CACHE_MAX_ENTRIES      64
+#define CACHE_MIN_TTL          120 /* 2 minutes */
+#define CACHE_DEFAULT_TTL      300 /* 5 minutes */
 
 #define IS_DFS_INTERLINK(v) (((v) & DFSREF_REFERRAL_SERVER) && !((v) & DFSREF_STORAGE_SERVER))
 
@@ -50,10 +52,9 @@ struct cache_entry {
 };
 
 static struct kmem_cache *cache_slab __read_mostly;
-static struct workqueue_struct *dfscache_wq __read_mostly;
+struct workqueue_struct *dfscache_wq;
 
-static int cache_ttl;
-static DEFINE_SPINLOCK(cache_ttl_lock);
+atomic_t dfs_cache_ttl;
 
 static struct nls_table *cache_cp;
 
@@ -65,10 +66,6 @@ static atomic_t cache_count;
 static struct hlist_head cache_htable[CACHE_HTABLE_SIZE];
 static DECLARE_RWSEM(htable_rw_lock);
 
-static void refresh_cache_worker(struct work_struct *work);
-
-static DECLARE_DELAYED_WORK(refresh_task, refresh_cache_worker);
-
 /**
  * dfs_cache_canonical_path - get a canonical DFS path
  *
@@ -290,7 +287,9 @@ int dfs_cache_init(void)
        int rc;
        int i;
 
-       dfscache_wq = alloc_workqueue("cifs-dfscache", WQ_FREEZABLE | WQ_UNBOUND, 1);
+       dfscache_wq = alloc_workqueue("cifs-dfscache",
+                                     WQ_UNBOUND|WQ_FREEZABLE|WQ_MEM_RECLAIM,
+                                     0);
        if (!dfscache_wq)
                return -ENOMEM;
 
@@ -306,6 +305,7 @@ int dfs_cache_init(void)
                INIT_HLIST_HEAD(&cache_htable[i]);
 
        atomic_set(&cache_count, 0);
+       atomic_set(&dfs_cache_ttl, CACHE_DEFAULT_TTL);
        cache_cp = load_nls("utf8");
        if (!cache_cp)
                cache_cp = load_nls_default();
@@ -480,6 +480,7 @@ static struct cache_entry *add_cache_entry_locked(struct dfs_info3_param *refs,
        int rc;
        struct cache_entry *ce;
        unsigned int hash;
+       int ttl;
 
        WARN_ON(!rwsem_is_locked(&htable_rw_lock));
 
@@ -496,15 +497,8 @@ static struct cache_entry *add_cache_entry_locked(struct dfs_info3_param *refs,
        if (IS_ERR(ce))
                return ce;
 
-       spin_lock(&cache_ttl_lock);
-       if (!cache_ttl) {
-               cache_ttl = ce->ttl;
-               queue_delayed_work(dfscache_wq, &refresh_task, cache_ttl * HZ);
-       } else {
-               cache_ttl = min_t(int, cache_ttl, ce->ttl);
-               mod_delayed_work(dfscache_wq, &refresh_task, cache_ttl * HZ);
-       }
-       spin_unlock(&cache_ttl_lock);
+       ttl = min_t(int, atomic_read(&dfs_cache_ttl), ce->ttl);
+       atomic_set(&dfs_cache_ttl, ttl);
 
        hlist_add_head(&ce->hlist, &cache_htable[hash]);
        dump_ce(ce);
@@ -616,7 +610,6 @@ static struct cache_entry *lookup_cache_entry(const char *path)
  */
 void dfs_cache_destroy(void)
 {
-       cancel_delayed_work_sync(&refresh_task);
        unload_nls(cache_cp);
        flush_cache_ents();
        kmem_cache_destroy(cache_slab);
@@ -1142,6 +1135,7 @@ static bool target_share_equal(struct TCP_Server_Info *server, const char *s1, c
  * target shares in @refs.
  */
 static void mark_for_reconnect_if_needed(struct TCP_Server_Info *server,
+                                        const char *path,
                                         struct dfs_cache_tgt_list *old_tl,
                                         struct dfs_cache_tgt_list *new_tl)
 {
@@ -1153,8 +1147,10 @@ static void mark_for_reconnect_if_needed(struct TCP_Server_Info *server,
                     nit = dfs_cache_get_next_tgt(new_tl, nit)) {
                        if (target_share_equal(server,
                                               dfs_cache_get_tgt_name(oit),
-                                              dfs_cache_get_tgt_name(nit)))
+                                              dfs_cache_get_tgt_name(nit))) {
+                               dfs_cache_noreq_update_tgthint(path, nit);
                                return;
+                       }
                }
        }
 
@@ -1162,13 +1158,28 @@ static void mark_for_reconnect_if_needed(struct TCP_Server_Info *server,
        cifs_signal_cifsd_for_reconnect(server, true);
 }
 
+static bool is_ses_good(struct cifs_ses *ses)
+{
+       struct TCP_Server_Info *server = ses->server;
+       struct cifs_tcon *tcon = ses->tcon_ipc;
+       bool ret;
+
+       spin_lock(&ses->ses_lock);
+       spin_lock(&ses->chan_lock);
+       ret = !cifs_chan_needs_reconnect(ses, server) &&
+               ses->ses_status == SES_GOOD &&
+               !tcon->need_reconnect;
+       spin_unlock(&ses->chan_lock);
+       spin_unlock(&ses->ses_lock);
+       return ret;
+}
+
 /* Refresh dfs referral of tcon and mark it for reconnect if needed */
-static int __refresh_tcon(const char *path, struct cifs_tcon *tcon, bool force_refresh)
+static int __refresh_tcon(const char *path, struct cifs_ses *ses, bool force_refresh)
 {
        struct dfs_cache_tgt_list old_tl = DFS_CACHE_TGT_LIST_INIT(old_tl);
        struct dfs_cache_tgt_list new_tl = DFS_CACHE_TGT_LIST_INIT(new_tl);
-       struct cifs_ses *ses = CIFS_DFS_ROOT_SES(tcon->ses);
-       struct cifs_tcon *ipc = ses->tcon_ipc;
+       struct TCP_Server_Info *server = ses->server;
        bool needs_refresh = false;
        struct cache_entry *ce;
        unsigned int xid;
@@ -1190,20 +1201,19 @@ static int __refresh_tcon(const char *path, struct cifs_tcon *tcon, bool force_r
                goto out;
        }
 
-       spin_lock(&ipc->tc_lock);
-       if (ipc->status != TID_GOOD) {
-               spin_unlock(&ipc->tc_lock);
-               cifs_dbg(FYI, "%s: skip cache refresh due to disconnected ipc\n", __func__);
+       ses = CIFS_DFS_ROOT_SES(ses);
+       if (!is_ses_good(ses)) {
+               cifs_dbg(FYI, "%s: skip cache refresh due to disconnected ipc\n",
+                        __func__);
                goto out;
        }
-       spin_unlock(&ipc->tc_lock);
 
        ce = cache_refresh_path(xid, ses, path, true);
        if (!IS_ERR(ce)) {
                rc = get_targets(ce, &new_tl);
                up_read(&htable_rw_lock);
                cifs_dbg(FYI, "%s: get_targets: %d\n", __func__, rc);
-               mark_for_reconnect_if_needed(tcon->ses->server, &old_tl, &new_tl);
+               mark_for_reconnect_if_needed(server, path, &old_tl, &new_tl);
        }
 
 out:
@@ -1216,10 +1226,11 @@ out:
 static int refresh_tcon(struct cifs_tcon *tcon, bool force_refresh)
 {
        struct TCP_Server_Info *server = tcon->ses->server;
+       struct cifs_ses *ses = tcon->ses;
 
        mutex_lock(&server->refpath_lock);
        if (server->leaf_fullpath)
-               __refresh_tcon(server->leaf_fullpath + 1, tcon, force_refresh);
+               __refresh_tcon(server->leaf_fullpath + 1, ses, force_refresh);
        mutex_unlock(&server->refpath_lock);
        return 0;
 }
@@ -1263,60 +1274,32 @@ int dfs_cache_remount_fs(struct cifs_sb_info *cifs_sb)
        return refresh_tcon(tcon, true);
 }
 
-/*
- * Worker that will refresh DFS cache from all active mounts based on lowest TTL value
- * from a DFS referral.
- */
-static void refresh_cache_worker(struct work_struct *work)
+/* Refresh all DFS referrals related to DFS tcon */
+void dfs_cache_refresh(struct work_struct *work)
 {
        struct TCP_Server_Info *server;
-       struct cifs_tcon *tcon, *ntcon;
-       struct list_head tcons;
+       struct dfs_root_ses *rses;
+       struct cifs_tcon *tcon;
        struct cifs_ses *ses;
 
-       INIT_LIST_HEAD(&tcons);
+       tcon = container_of(work, struct cifs_tcon, dfs_cache_work.work);
+       ses = tcon->ses;
+       server = ses->server;
 
-       spin_lock(&cifs_tcp_ses_lock);
-       list_for_each_entry(server, &cifs_tcp_ses_list, tcp_ses_list) {
-               spin_lock(&server->srv_lock);
-               if (!server->leaf_fullpath) {
-                       spin_unlock(&server->srv_lock);
-                       continue;
-               }
-               spin_unlock(&server->srv_lock);
-
-               list_for_each_entry(ses, &server->smb_ses_list, smb_ses_list) {
-                       if (ses->tcon_ipc) {
-                               ses->ses_count++;
-                               list_add_tail(&ses->tcon_ipc->ulist, &tcons);
-                       }
-                       list_for_each_entry(tcon, &ses->tcon_list, tcon_list) {
-                               if (!tcon->ipc) {
-                                       tcon->tc_count++;
-                                       list_add_tail(&tcon->ulist, &tcons);
-                               }
-                       }
-               }
-       }
-       spin_unlock(&cifs_tcp_ses_lock);
-
-       list_for_each_entry_safe(tcon, ntcon, &tcons, ulist) {
-               struct TCP_Server_Info *server = tcon->ses->server;
-
-               list_del_init(&tcon->ulist);
+       mutex_lock(&server->refpath_lock);
+       if (server->leaf_fullpath)
+               __refresh_tcon(server->leaf_fullpath + 1, ses, false);
+       mutex_unlock(&server->refpath_lock);
 
+       list_for_each_entry(rses, &tcon->dfs_ses_list, list) {
+               ses = rses->ses;
+               server = ses->server;
                mutex_lock(&server->refpath_lock);
                if (server->leaf_fullpath)
-                       __refresh_tcon(server->leaf_fullpath + 1, tcon, false);
+                       __refresh_tcon(server->leaf_fullpath + 1, ses, false);
                mutex_unlock(&server->refpath_lock);
-
-               if (tcon->ipc)
-                       cifs_put_smb_ses(tcon->ses);
-               else
-                       cifs_put_tcon(tcon);
        }
 
-       spin_lock(&cache_ttl_lock);
-       queue_delayed_work(dfscache_wq, &refresh_task, cache_ttl * HZ);
-       spin_unlock(&cache_ttl_lock);
+       queue_delayed_work(dfscache_wq, &tcon->dfs_cache_work,
+                          atomic_read(&dfs_cache_ttl) * HZ);
 }