lib: Move "message_send_all" to serverid.c
[nivanova/samba-autobuild/.git] / source3 / lib / serverid.c
index b1f6a5711ef65abac9d7939f600c9cca143066bf..f2c64001d7407bf92c1be0d69c50b310013474a1 100644 (file)
 #include "system/filesys.h"
 #include "serverid.h"
 #include "util_tdb.h"
-#include "dbwrap.h"
-#include "lib/util/tdb_wrap.h"
+#include "dbwrap/dbwrap.h"
+#include "dbwrap/dbwrap_open.h"
+#include "lib/tdb_wrap/tdb_wrap.h"
+#include "lib/param/param.h"
+#include "ctdbd_conn.h"
+#include "messages.h"
+#include "lib/messages_dgm.h"
 
 struct serverid_key {
        pid_t pid;
+       uint32_t task_id;
        uint32_t vnn;
 };
 
@@ -34,37 +40,40 @@ struct serverid_data {
        uint32_t msg_flags;
 };
 
-bool serverid_parent_init(TALLOC_CTX *mem_ctx)
+static struct db_context *serverid_db(void)
 {
-       struct tdb_wrap *db;
+       static struct db_context *db;
+       char *db_path;
 
-       /*
-        * Open the tdb in the parent process (smbd) so that our
-        * CLEAR_IF_FIRST optimization in tdb_reopen_all can properly
-        * work.
-        */
+       if (db != NULL) {
+               return db;
+       }
 
-       db = tdb_wrap_open(mem_ctx, lock_path("serverid.tdb"),
-                          0, TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT,
-                          0644);
-       if (db == NULL) {
-               DEBUG(1, ("could not open serverid.tdb: %s\n",
-                         strerror(errno)));
-               return false;
+       db_path = lock_path("serverid.tdb");
+       if (db_path == NULL) {
+               return NULL;
        }
-       return true;
+
+       db = db_open(NULL, db_path, 0,
+                    TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH,
+                    O_RDWR|O_CREAT, 0644, DBWRAP_LOCK_ORDER_2,
+                    DBWRAP_FLAG_NONE);
+       TALLOC_FREE(db_path);
+       return db;
 }
 
-static struct db_context *serverid_db(void)
+bool serverid_parent_init(TALLOC_CTX *mem_ctx)
 {
-       static struct db_context *db;
+       struct db_context *db;
 
-       if (db != NULL) {
-               return db;
+       db = serverid_db();
+       if (db == NULL) {
+               DEBUG(1, ("could not open serverid.tdb: %s\n",
+                         strerror(errno)));
+               return false;
        }
-       db = db_open(NULL, lock_path("serverid.tdb"), 0,
-                    TDB_DEFAULT|TDB_CLEAR_IF_FIRST|TDB_INCOMPATIBLE_HASH, O_RDWR|O_CREAT, 0644);
-       return db;
+
+       return true;
 }
 
 static void serverid_fill_key(const struct server_id *id,
@@ -72,6 +81,7 @@ static void serverid_fill_key(const struct server_id *id,
 {
        ZERO_STRUCTP(key);
        key->pid = id->pid;
+       key->task_id = id->task_id;
        key->vnn = id->vnn;
 }
 
@@ -93,7 +103,7 @@ bool serverid_register(const struct server_id id, uint32_t msg_flags)
        serverid_fill_key(&id, &key);
        tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
 
-       rec = db->fetch_locked(db, talloc_tos(), tdbkey);
+       rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
        if (rec == NULL) {
                DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
                return false;
@@ -104,64 +114,18 @@ bool serverid_register(const struct server_id id, uint32_t msg_flags)
        data.msg_flags = msg_flags;
 
        tdbdata = make_tdb_data((uint8_t *)&data, sizeof(data));
-       status = rec->store(rec, tdbdata, 0);
+       status = dbwrap_record_store(rec, tdbdata, 0);
        if (!NT_STATUS_IS_OK(status)) {
                DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
                          nt_errstr(status)));
                goto done;
        }
-       ret = true;
-done:
-       TALLOC_FREE(rec);
-       return ret;
-}
-
-bool serverid_register_msg_flags(const struct server_id id, bool do_reg,
-                                uint32_t msg_flags)
-{
-       struct db_context *db;
-       struct serverid_key key;
-       struct serverid_data *data;
-       struct db_record *rec;
-       TDB_DATA tdbkey;
-       NTSTATUS status;
-       bool ret = false;
-
-       db = serverid_db();
-       if (db == NULL) {
-               return false;
-       }
-
-       serverid_fill_key(&id, &key);
-       tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
-
-       rec = db->fetch_locked(db, talloc_tos(), tdbkey);
-       if (rec == NULL) {
-               DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
-               return false;
-       }
-
-       if (rec->value.dsize != sizeof(struct serverid_data)) {
-               DEBUG(1, ("serverid record has unexpected size %d "
-                         "(wanted %d)\n", (int)rec->value.dsize,
-                         (int)sizeof(struct serverid_data)));
-               goto done;
-       }
-
-       data = (struct serverid_data *)rec->value.dptr;
 
-       if (do_reg) {
-               data->msg_flags |= msg_flags;
-       } else {
-               data->msg_flags &= ~msg_flags;
+       if (lp_clustering()) {
+               register_with_ctdbd(messaging_ctdbd_connection(), id.unique_id,
+                                   NULL, NULL);
        }
 
-       status = rec->store(rec, rec->value, 0);
-       if (!NT_STATUS_IS_OK(status)) {
-               DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
-                         nt_errstr(status)));
-               goto done;
-       }
        ret = true;
 done:
        TALLOC_FREE(rec);
@@ -185,13 +149,13 @@ bool serverid_deregister(struct server_id id)
        serverid_fill_key(&id, &key);
        tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
 
-       rec = db->fetch_locked(db, talloc_tos(), tdbkey);
+       rec = dbwrap_fetch_locked(db, talloc_tos(), tdbkey);
        if (rec == NULL) {
                DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
                return false;
        }
 
-       status = rec->delete_rec(rec);
+       status = dbwrap_record_delete(rec);
        if (!NT_STATUS_IS_OK(status)) {
                DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
                          nt_errstr(status)));
@@ -203,55 +167,40 @@ done:
        return ret;
 }
 
-struct serverid_exists_state {
-       const struct server_id *id;
-       bool exists;
-};
-
-static int server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
+static bool serverid_exists_local(const struct server_id *id)
 {
-       struct serverid_exists_state *state =
-               (struct serverid_exists_state *)priv;
+       bool exists = process_exists_by_pid(id->pid);
+       uint64_t unique;
+       int ret;
 
-       if (data.dsize != sizeof(struct serverid_data)) {
-               return -1;
+       if (!exists) {
+               return false;
        }
 
-       /*
-        * Use memcmp, not direct compare. data.dptr might not be
-        * aligned.
-        */
-       state->exists = (memcmp(&state->id->unique_id, data.dptr,
-                               sizeof(state->id->unique_id)) == 0);
-       return 0;
-}
-
-bool serverid_exists(const struct server_id *id)
-{
-       struct db_context *db;
-       struct serverid_exists_state state;
-       struct serverid_key key;
-       TDB_DATA tdbkey;
-
-       if (lp_clustering() && !process_exists(*id)) {
-               return false;
+       if (id->unique_id == SERVERID_UNIQUE_ID_NOT_TO_VERIFY) {
+               return true;
        }
 
-       db = serverid_db();
-       if (db == NULL) {
+       ret = messaging_dgm_get_unique(id->pid, &unique);
+       if (ret != 0) {
                return false;
        }
 
-       serverid_fill_key(id, &key);
-       tdbkey = make_tdb_data((uint8_t *)&key, sizeof(key));
+       return (unique == id->unique_id);
+}
 
-       state.id = id;
-       state.exists = false;
+bool serverid_exists(const struct server_id *id)
+{
+       if (procid_is_local(id)) {
+               return serverid_exists_local(id);
+       }
 
-       if (db->parse_record(db, tdbkey, server_exists_parse, &state) == -1) {
-               return false;
+       if (lp_clustering()) {
+               return ctdbd_process_exists(messaging_ctdbd_connection(),
+                                           id->vnn, id->pid);
        }
-       return state.exists;
+
+       return false;
 }
 
 static bool serverid_rec_parse(const struct db_record *rec,
@@ -259,22 +208,28 @@ static bool serverid_rec_parse(const struct db_record *rec,
 {
        struct serverid_key key;
        struct serverid_data data;
+       TDB_DATA tdbkey;
+       TDB_DATA tdbdata;
+
+       tdbkey = dbwrap_record_get_key(rec);
+       tdbdata = dbwrap_record_get_value(rec);
 
-       if (rec->key.dsize != sizeof(key)) {
+       if (tdbkey.dsize != sizeof(key)) {
                DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
-                         (int)rec->key.dsize));
+                         (int)tdbkey.dsize));
                return false;
        }
-       if (rec->value.dsize != sizeof(data)) {
+       if (tdbdata.dsize != sizeof(data)) {
                DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
-                         (int)rec->value.dsize));
+                         (int)tdbdata.dsize));
                return false;
        }
 
-       memcpy(&key, rec->key.dptr, sizeof(key));
-       memcpy(&data, rec->value.dptr, sizeof(data));
+       memcpy(&key, tdbkey.dptr, sizeof(key));
+       memcpy(&data, tdbdata.dptr, sizeof(data));
 
        id->pid = key.pid;
+       id->task_id = key.task_id;
        id->vnn = key.vnn;
        id->unique_id = data.unique_id;
        *msg_flags = data.msg_flags;
@@ -306,6 +261,7 @@ bool serverid_traverse_read(int (*fn)(const struct server_id *id,
 {
        struct db_context *db;
        struct serverid_traverse_read_state state;
+       NTSTATUS status;
 
        db = serverid_db();
        if (db == NULL) {
@@ -313,7 +269,10 @@ bool serverid_traverse_read(int (*fn)(const struct server_id *id,
        }
        state.fn = fn;
        state.private_data = private_data;
-       return db->traverse_read(db, serverid_traverse_read_fn, &state);
+
+       status = dbwrap_traverse_read(db, serverid_traverse_read_fn, &state,
+                                     NULL);
+       return NT_STATUS_IS_OK(status);
 }
 
 struct serverid_traverse_state {
@@ -342,6 +301,7 @@ bool serverid_traverse(int (*fn)(struct db_record *rec,
 {
        struct db_context *db;
        struct serverid_traverse_state state;
+       NTSTATUS status;
 
        db = serverid_db();
        if (db == NULL) {
@@ -349,5 +309,102 @@ bool serverid_traverse(int (*fn)(struct db_record *rec,
        }
        state.fn = fn;
        state.private_data = private_data;
-       return db->traverse(db, serverid_traverse_fn, &state);
+
+       status = dbwrap_traverse(db, serverid_traverse_fn, &state, NULL);
+       return NT_STATUS_IS_OK(status);
+}
+
+struct msg_all {
+       struct messaging_context *msg_ctx;
+       int msg_type;
+       uint32_t msg_flag;
+       const void *buf;
+       size_t len;
+       int n_sent;
+};
+
+/****************************************************************************
+ Send one of the messages for the broadcast.
+****************************************************************************/
+
+static int traverse_fn(struct db_record *rec, const struct server_id *id,
+                      uint32_t msg_flags, void *state)
+{
+       struct msg_all *msg_all = (struct msg_all *)state;
+       NTSTATUS status;
+
+       /* Don't send if the receiver hasn't registered an interest. */
+
+       if((msg_flags & msg_all->msg_flag) == 0) {
+               return 0;
+       }
+
+       /* If the msg send fails because the pid was not found (i.e. smbd died),
+        * the msg has already been deleted from the messages.tdb.*/
+
+       status = messaging_send_buf(msg_all->msg_ctx, *id, msg_all->msg_type,
+                                   (const uint8_t *)msg_all->buf, msg_all->len);
+
+       if (NT_STATUS_EQUAL(status, NT_STATUS_INVALID_HANDLE)) {
+               struct server_id_buf idbuf;
+
+               /*
+                * If the pid was not found delete the entry from
+                * serverid.tdb
+                */
+
+               DEBUG(2, ("pid %s doesn't exist\n",
+                         server_id_str_buf(*id, &idbuf)));
+
+               dbwrap_record_delete(rec);
+       }
+       msg_all->n_sent++;
+       return 0;
+}
+
+/**
+ * Send a message to all smbd processes.
+ *
+ * It isn't very efficient, but should be OK for the sorts of
+ * applications that use it. When we need efficient broadcast we can add
+ * it.
+ *
+ * @param n_sent Set to the number of messages sent.  This should be
+ * equal to the number of processes, but be careful for races.
+ *
+ * @retval True for success.
+ **/
+bool message_send_all(struct messaging_context *msg_ctx,
+                     int msg_type,
+                     const void *buf, size_t len,
+                     int *n_sent)
+{
+       struct msg_all msg_all;
+
+       msg_all.msg_type = msg_type;
+       if (msg_type < 0x100) {
+               msg_all.msg_flag = FLAG_MSG_GENERAL;
+       } else if (msg_type > 0x100 && msg_type < 0x200) {
+               msg_all.msg_flag = FLAG_MSG_NMBD;
+       } else if (msg_type > 0x200 && msg_type < 0x300) {
+               msg_all.msg_flag = FLAG_MSG_PRINT_GENERAL;
+       } else if (msg_type > 0x300 && msg_type < 0x400) {
+               msg_all.msg_flag = FLAG_MSG_SMBD;
+       } else if (msg_type > 0x400 && msg_type < 0x600) {
+               msg_all.msg_flag = FLAG_MSG_WINBIND;
+       } else if (msg_type > 4000 && msg_type < 5000) {
+               msg_all.msg_flag = FLAG_MSG_DBWRAP;
+       } else {
+               return false;
+       }
+
+       msg_all.buf = buf;
+       msg_all.len = len;
+       msg_all.n_sent = 0;
+       msg_all.msg_ctx = msg_ctx;
+
+       serverid_traverse(traverse_fn, &msg_all);
+       if (n_sent)
+               *n_sent = msg_all.n_sent;
+       return true;
 }