ctdb-transport: Add transport api
[amitay/samba.git] / ctdb / transport / transport_api.c
1 #include "replace.h"
2 #include "system/network.h"
3
4 #include <talloc.h>
5 #include <tevent.h>
6
7 #include "lib/util/debug.h"
8 #include "lib/util/blocking.h"
9 #include "lib/util/tevent_unix.h"
10 #include "lib/util/byteorder.h"
11 #include "lib/messaging/messages_dgm.h"
12
13 #include "common/path.h"
14 #include "common/pkt_read.h"
15 #include "common/pkt_write.h"
16
17 #include "transport/transport.h"
18 #include "transport/transport_header.h"
19 #include "transport/transport_packet.h"
20 #include "transport/transport_db.h"
21 #include "transport/transport_api.h"
22
/*
 * Per-process transport state.  Created by transport_init(), which also
 * installs a talloc destructor that shuts the messaging layer down.
 */
struct transport_context {
        /* Registration handle returned by transport_db_register(). */
        struct transport_db_context *db;
        /* Its address is handed to messaging_dgm_init(). */
        uint64_t unique;
        /* Our endpoint id; transport_read() drops packets for any other. */
        uint32_t endpoint_id;

        /* Caller-supplied callbacks plus the opaque pointer given back. */
        void (*disconnect_cb)(void *cb_private);
        void (*read_cb)(struct transport_endpoint *from,
                        struct transport_packet *packet,
                        void *cb_private);
        void *private_data;
};
34
/*
 * Static helpers referenced by transport_init() before their definitions:
 * the messaging_dgm receive callback and the context's talloc destructor.
 */
static void transport_read(struct tevent_context *ev,
                           const uint8_t *msg,
                           size_t msg_len,
                           int *fds,
                           size_t num_fds,
                           void *private_data);
static int transport_context_destructor(struct transport_context *transport);
42
43 struct transport_context *transport_init(
44                 TALLOC_CTX *mem_ctx,
45                 struct tevent_context *ev,
46                 uint32_t endpoint_id,
47                 void (*disconnect_cb)(void *private_data),
48                 void (*read_cb)(struct transport_endpoint *src,
49                                 struct transport_packet *pkt,
50                                 void *private_data),
51                 void *private_data)
52 {
53         struct transport_context *transport;
54         int ret;
55
56         transport = talloc(mem_ctx, struct transport_context);
57         if (transport == NULL) {
58                 return NULL;
59         }
60
61         transport->db = transport_db_register(transport, &endpoint_id);
62         if (transport->db == NULL) {
63                 goto fail;
64         }
65         transport->endpoint_id = endpoint_id;
66
67         transport->disconnect_cb = disconnect_cb;
68         transport->read_cb = read_cb;
69         transport->private_data = private_data;
70
71         ret = messaging_dgm_init(ev,
72                                  &transport->unique,
73                                  path_rundir(),
74                                  path_rundir(),
75                                  transport_read,
76                                  transport);
77         if (ret != 0) {
78                 D_ERR("transport: Failed to initialize messaging\n");
79                 goto fail;
80         }
81
82         talloc_set_destructor(transport, transport_context_destructor);
83
84         return 0;
85
86 fail:
87         talloc_free(transport);
88         return NULL;
89 }
90
/*
 * talloc destructor installed by transport_init(): shuts down the
 * messaging_dgm layer.  Always returns 0, so the free proceeds.
 *
 * NOTE(review): messaging_dgm_destroy() takes no context argument, which
 * suggests process-wide messaging state -- if several transport contexts
 * can coexist, freeing one may tear messaging down for all; confirm.
 */
static int transport_context_destructor(struct transport_context *transport)
{
        (void)transport;        /* nothing per-context to release here */

        messaging_dgm_destroy();
        return 0;
}
97
98 static void transport_read(struct tevent_context *ev,
99                            const uint8_t *msg,
100                            size_t msg_len,
101                            int *fds,
102                            size_t num_fds,
103                            void *private_data)
104 {
105         struct transport_context *transport = talloc_get_type_abort(
106                 private_data, struct transport_context);
107         struct transport_packet *pkt;
108         struct transport_header header;
109         int ret;
110         bool ok;
111
112         pkt = transport_packet_init(transport, msg, msg_len);
113         if (pkt == NULL) {
114                 D_ERR("transport: Dropping packet\n");
115                 return;
116         }
117
118         ret = transport_header_pull(pkt, &header);
119         if (ret != 0) {
120                 D_ERR("transport: Invalid packet, dropping\n");
121                 goto done;
122         }
123
124         ok = transport_header_verify(&header);
125         if (!ok) {
126                 D_ERR("transport: Invalid header, dropping\n");
127                 goto done;
128         }
129
130         if (header.dst.endpoint != transport->endpoint_id) {
131                 D_ERR("transport: Wrong destination 0x%x, dropping\n",
132                       header.dst.endpoint);
133                 goto done;
134         }
135
136         transport->read_cb(&header.src, pkt, transport->private_data);
137
138 done:
139         talloc_free(pkt);
140 }
141
142 int transport_write(struct transport_context *transport,
143                     struct transport_endpoint *dst,
144                     struct transport_packet *pkt)
145 {
146         struct iovec iov[2];
147         struct transport_header header;
148         struct transport_endpoint src;
149         uint8_t *buf = NULL;
150         size_t buflen = 0;
151         pid_t pid;
152         int ret;
153
154         pid = transport_db_lookup(transport->db, dst->endpoint);
155         if (pid == -1) {
156                 D_ERR("transport: Unknown endpoint id 0x%x\n",
157                       dst->endpoint);
158                 return EINVAL;
159         }
160
161         src.node = CTDB_NODE_LOCAL;
162         src.endpoint = transport->endpoint_id;
163
164         transport_header_fill(&header, &src, dst);
165
166         ret = transport_packet_finish(pkt, &buf, &buflen);
167         if (ret != 0) {
168                 return EINVAL;
169         }
170
171         iov[0].iov_base = &header;
172         iov[0].iov_len = transport_header_len(&header);
173
174         iov[1].iov_base = buf;
175         iov[1].iov_len = buflen;
176
177         ret = messaging_dgm_send(pid, iov, 2, NULL, 0);
178         if (ret != 0) {
179                 D_ERR("transport: Failed to send message\n");
180                 return ret;
181         }
182
183         return 0;
184 }