Merge tag 'rust-6.9' of https://github.com/Rust-for-Linux/linux
[sfrench/cifs-2.6.git] / tools / net / ynl / lib / ynl.c
1 // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2 #include <errno.h>
3 #include <poll.h>
4 #include <string.h>
5 #include <stdlib.h>
6 #include <linux/types.h>
7
8 #include <libmnl/libmnl.h>
9 #include <linux/genetlink.h>
10
11 #include "ynl.h"
12
#define ARRAY_SIZE(arr)		(sizeof(arr) / sizeof(*(arr)))

/*
 * Error reporting helpers. @yse may be NULL, in which case nothing is
 * recorded; this lets ynl_sock_create() report into an optional
 * caller-supplied struct ynl_error. @yse is evaluated more than once and
 * must be free of side effects.
 *
 * __yerr_msg()  - printf-style format into yse->msg; snprintf() always
 *                 NUL-terminates and silently truncates.
 * __yerr_code() - store the error code into yse->code.
 * __yerr()      - both. The code is stored BEFORE the message is formatted:
 *                 __perr()/perr() pass errno as the code, and C allows any
 *                 library call not documented to use errno (snprintf()
 *                 included) to change it.
 */
#define __yerr_msg(yse, _msg...)					\
	({								\
		struct ynl_error *_yse = (yse);				\
									\
		if (_yse)						\
			snprintf(_yse->msg, sizeof(_yse->msg), _msg);	\
	})

#define __yerr_code(yse, _code...)					\
	({								\
		struct ynl_error *_yse = (yse);				\
									\
		if (_yse)						\
			_yse->code = _code;				\
	})

#define __yerr(yse, _code, _msg...)					\
	({								\
		__yerr_code(yse, _code);				\
		__yerr_msg(yse, _msg);					\
	})

#define __perr(yse, _msg)		__yerr(yse, errno, _msg)

/* Same helpers, recording into the socket's own error state. */
#define yerr_msg(_ys, _msg...)		__yerr_msg(&(_ys)->err, _msg)
#define yerr(_ys, _code, _msg...)	__yerr(&(_ys)->err, _code, _msg)
#define perr(_ys, _msg)			__yerr(&(_ys)->err, errno, _msg)
45
46 /* -- Netlink boiler plate */
47 static int
48 ynl_err_walk_report_one(struct ynl_policy_nest *policy, unsigned int type,
49                         char *str, int str_sz, int *n)
50 {
51         if (!policy) {
52                 if (*n < str_sz)
53                         *n += snprintf(str, str_sz, "!policy");
54                 return 1;
55         }
56
57         if (type > policy->max_attr) {
58                 if (*n < str_sz)
59                         *n += snprintf(str, str_sz, "!oob");
60                 return 1;
61         }
62
63         if (!policy->table[type].name) {
64                 if (*n < str_sz)
65                         *n += snprintf(str, str_sz, "!name");
66                 return 1;
67         }
68
69         if (*n < str_sz)
70                 *n += snprintf(str, str_sz - *n,
71                                ".%s", policy->table[type].name);
72         return 0;
73 }
74
75 static int
76 ynl_err_walk(struct ynl_sock *ys, void *start, void *end, unsigned int off,
77              struct ynl_policy_nest *policy, char *str, int str_sz,
78              struct ynl_policy_nest **nest_pol)
79 {
80         unsigned int astart_off, aend_off;
81         const struct nlattr *attr;
82         unsigned int data_len;
83         unsigned int type;
84         bool found = false;
85         int n = 0;
86
87         if (!policy) {
88                 if (n < str_sz)
89                         n += snprintf(str, str_sz, "!policy");
90                 return n;
91         }
92
93         data_len = end - start;
94
95         mnl_attr_for_each_payload(start, data_len) {
96                 astart_off = (char *)attr - (char *)start;
97                 aend_off = astart_off + mnl_attr_get_payload_len(attr);
98                 if (aend_off <= off)
99                         continue;
100
101                 found = true;
102                 break;
103         }
104         if (!found)
105                 return 0;
106
107         off -= astart_off;
108
109         type = mnl_attr_get_type(attr);
110
111         if (ynl_err_walk_report_one(policy, type, str, str_sz, &n))
112                 return n;
113
114         if (!off) {
115                 if (nest_pol)
116                         *nest_pol = policy->table[type].nest;
117                 return n;
118         }
119
120         if (!policy->table[type].nest) {
121                 if (n < str_sz)
122                         n += snprintf(str, str_sz, "!nest");
123                 return n;
124         }
125
126         off -= sizeof(struct nlattr);
127         start =  mnl_attr_get_payload(attr);
128         end = start + mnl_attr_get_payload_len(attr);
129
130         return n + ynl_err_walk(ys, start, end, off, policy->table[type].nest,
131                                 &str[n], str_sz - n, nest_pol);
132 }
133
/*
 * Extended-ACK attribute IDs that directly follow NLMSGERR_ATTR_POLICY.
 * NLMSGERR_ATTR_MAX is widened by two to cover them. Inside its own
 * definition the name is not re-expanded, so it refers to the uAPI value.
 */
#define NLMSGERR_ATTR_MISS_TYPE	(NLMSGERR_ATTR_POLICY + 1)
#define NLMSGERR_ATTR_MISS_NEST	(NLMSGERR_ATTR_MISS_TYPE + 1)
#define NLMSGERR_ATTR_MAX	(NLMSGERR_ATTR_MAX + 2)
137
138 static int
139 ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh,
140                   unsigned int hlen)
141 {
142         const struct nlattr *tb[NLMSGERR_ATTR_MAX + 1] = {};
143         char miss_attr[sizeof(ys->err.msg)];
144         char bad_attr[sizeof(ys->err.msg)];
145         const struct nlattr *attr;
146         const char *str = NULL;
147
148         if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS)) {
149                 yerr_msg(ys, "%s", strerror(ys->err.code));
150                 return MNL_CB_OK;
151         }
152
153         mnl_attr_for_each(attr, nlh, hlen) {
154                 unsigned int len, type;
155
156                 len = mnl_attr_get_payload_len(attr);
157                 type = mnl_attr_get_type(attr);
158
159                 if (type > NLMSGERR_ATTR_MAX)
160                         continue;
161
162                 tb[type] = attr;
163
164                 switch (type) {
165                 case NLMSGERR_ATTR_OFFS:
166                 case NLMSGERR_ATTR_MISS_TYPE:
167                 case NLMSGERR_ATTR_MISS_NEST:
168                         if (len != sizeof(__u32))
169                                 return MNL_CB_ERROR;
170                         break;
171                 case NLMSGERR_ATTR_MSG:
172                         str = mnl_attr_get_payload(attr);
173                         if (str[len - 1])
174                                 return MNL_CB_ERROR;
175                         break;
176                 default:
177                         break;
178                 }
179         }
180
181         bad_attr[0] = '\0';
182         miss_attr[0] = '\0';
183
184         if (tb[NLMSGERR_ATTR_OFFS]) {
185                 unsigned int n, off;
186                 void *start, *end;
187
188                 ys->err.attr_offs = mnl_attr_get_u32(tb[NLMSGERR_ATTR_OFFS]);
189
190                 n = snprintf(bad_attr, sizeof(bad_attr), "%sbad attribute: ",
191                              str ? " (" : "");
192
193                 start = mnl_nlmsg_get_payload_offset(ys->nlh,
194                                                      ys->family->hdr_len);
195                 end = mnl_nlmsg_get_payload_tail(ys->nlh);
196
197                 off = ys->err.attr_offs;
198                 off -= sizeof(struct nlmsghdr);
199                 off -= ys->family->hdr_len;
200
201                 n += ynl_err_walk(ys, start, end, off, ys->req_policy,
202                                   &bad_attr[n], sizeof(bad_attr) - n, NULL);
203
204                 if (n >= sizeof(bad_attr))
205                         n = sizeof(bad_attr) - 1;
206                 bad_attr[n] = '\0';
207         }
208         if (tb[NLMSGERR_ATTR_MISS_TYPE]) {
209                 struct ynl_policy_nest *nest_pol = NULL;
210                 unsigned int n, off, type;
211                 void *start, *end;
212                 int n2;
213
214                 type = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_TYPE]);
215
216                 n = snprintf(miss_attr, sizeof(miss_attr), "%smissing attribute: ",
217                              bad_attr[0] ? ", " : (str ? " (" : ""));
218
219                 start = mnl_nlmsg_get_payload_offset(ys->nlh,
220                                                      ys->family->hdr_len);
221                 end = mnl_nlmsg_get_payload_tail(ys->nlh);
222
223                 nest_pol = ys->req_policy;
224                 if (tb[NLMSGERR_ATTR_MISS_NEST]) {
225                         off = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_NEST]);
226                         off -= sizeof(struct nlmsghdr);
227                         off -= ys->family->hdr_len;
228
229                         n += ynl_err_walk(ys, start, end, off, ys->req_policy,
230                                           &miss_attr[n], sizeof(miss_attr) - n,
231                                           &nest_pol);
232                 }
233
234                 n2 = 0;
235                 ynl_err_walk_report_one(nest_pol, type, &miss_attr[n],
236                                         sizeof(miss_attr) - n, &n2);
237                 n += n2;
238
239                 if (n >= sizeof(miss_attr))
240                         n = sizeof(miss_attr) - 1;
241                 miss_attr[n] = '\0';
242         }
243
244         /* Implicitly depend on ys->err.code already set */
245         if (str)
246                 yerr_msg(ys, "Kernel %s: '%s'%s%s%s",
247                          ys->err.code ? "error" : "warning",
248                          str, bad_attr, miss_attr,
249                          bad_attr[0] || miss_attr[0] ? ")" : "");
250         else if (bad_attr[0] || miss_attr[0])
251                 yerr_msg(ys, "Kernel %s: %s%s",
252                          ys->err.code ? "error" : "warning",
253                          bad_attr, miss_attr);
254         else
255                 yerr_msg(ys, "%s", strerror(ys->err.code));
256
257         return MNL_CB_OK;
258 }
259
260 static int ynl_cb_error(const struct nlmsghdr *nlh, void *data)
261 {
262         const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh);
263         struct ynl_parse_arg *yarg = data;
264         unsigned int hlen;
265         int code;
266
267         code = err->error >= 0 ? err->error : -err->error;
268         yarg->ys->err.code = code;
269         errno = code;
270
271         hlen = sizeof(*err);
272         if (!(nlh->nlmsg_flags & NLM_F_CAPPED))
273                 hlen += mnl_nlmsg_get_payload_len(&err->msg);
274
275         ynl_ext_ack_check(yarg->ys, nlh, hlen);
276
277         return code ? MNL_CB_ERROR : MNL_CB_STOP;
278 }
279
280 static int ynl_cb_done(const struct nlmsghdr *nlh, void *data)
281 {
282         struct ynl_parse_arg *yarg = data;
283         int err;
284
285         err = *(int *)NLMSG_DATA(nlh);
286         if (err < 0) {
287                 yarg->ys->err.code = -err;
288                 errno = -err;
289
290                 ynl_ext_ack_check(yarg->ys, nlh, sizeof(int));
291
292                 return MNL_CB_ERROR;
293         }
294         return MNL_CB_STOP;
295 }
296
297 static int ynl_cb_noop(const struct nlmsghdr *nlh, void *data)
298 {
299         return MNL_CB_OK;
300 }
301
302 mnl_cb_t ynl_cb_array[NLMSG_MIN_TYPE] = {
303         [NLMSG_NOOP]    = ynl_cb_noop,
304         [NLMSG_ERROR]   = ynl_cb_error,
305         [NLMSG_DONE]    = ynl_cb_done,
306         [NLMSG_OVERRUN] = ynl_cb_noop,
307 };
308
309 /* Attribute validation */
310
311 int ynl_attr_validate(struct ynl_parse_arg *yarg, const struct nlattr *attr)
312 {
313         struct ynl_policy_attr *policy;
314         unsigned int type, len;
315         unsigned char *data;
316
317         data = mnl_attr_get_payload(attr);
318         len = mnl_attr_get_payload_len(attr);
319         type = mnl_attr_get_type(attr);
320         if (type > yarg->rsp_policy->max_attr) {
321                 yerr(yarg->ys, YNL_ERROR_INTERNAL,
322                      "Internal error, validating unknown attribute");
323                 return -1;
324         }
325
326         policy = &yarg->rsp_policy->table[type];
327
328         switch (policy->type) {
329         case YNL_PT_REJECT:
330                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
331                      "Rejected attribute (%s)", policy->name);
332                 return -1;
333         case YNL_PT_IGNORE:
334                 break;
335         case YNL_PT_U8:
336                 if (len == sizeof(__u8))
337                         break;
338                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
339                      "Invalid attribute (u8 %s)", policy->name);
340                 return -1;
341         case YNL_PT_U16:
342                 if (len == sizeof(__u16))
343                         break;
344                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
345                      "Invalid attribute (u16 %s)", policy->name);
346                 return -1;
347         case YNL_PT_U32:
348                 if (len == sizeof(__u32))
349                         break;
350                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
351                      "Invalid attribute (u32 %s)", policy->name);
352                 return -1;
353         case YNL_PT_U64:
354                 if (len == sizeof(__u64))
355                         break;
356                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
357                      "Invalid attribute (u64 %s)", policy->name);
358                 return -1;
359         case YNL_PT_UINT:
360                 if (len == sizeof(__u32) || len == sizeof(__u64))
361                         break;
362                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
363                      "Invalid attribute (uint %s)", policy->name);
364                 return -1;
365         case YNL_PT_FLAG:
366                 /* Let flags grow into real attrs, why not.. */
367                 break;
368         case YNL_PT_NEST:
369                 if (!len || len >= sizeof(*attr))
370                         break;
371                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
372                      "Invalid attribute (nest %s)", policy->name);
373                 return -1;
374         case YNL_PT_BINARY:
375                 if (!policy->len || len == policy->len)
376                         break;
377                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
378                      "Invalid attribute (binary %s)", policy->name);
379                 return -1;
380         case YNL_PT_NUL_STR:
381                 if ((!policy->len || len <= policy->len) && !data[len - 1])
382                         break;
383                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
384                      "Invalid attribute (string %s)", policy->name);
385                 return -1;
386         case YNL_PT_BITFIELD32:
387                 if (len == sizeof(struct nla_bitfield32))
388                         break;
389                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
390                      "Invalid attribute (bitfield32 %s)", policy->name);
391                 return -1;
392         default:
393                 yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
394                      "Invalid attribute (unknown %s)", policy->name);
395                 return -1;
396         }
397
398         return 0;
399 }
400
401 /* Generic code */
402
403 static void ynl_err_reset(struct ynl_sock *ys)
404 {
405         ys->err.code = 0;
406         ys->err.attr_offs = 0;
407         ys->err.msg[0] = 0;
408 }
409
410 struct nlmsghdr *ynl_msg_start(struct ynl_sock *ys, __u32 id, __u16 flags)
411 {
412         struct nlmsghdr *nlh;
413
414         ynl_err_reset(ys);
415
416         nlh = ys->nlh = mnl_nlmsg_put_header(ys->tx_buf);
417         nlh->nlmsg_type = id;
418         nlh->nlmsg_flags = flags;
419         nlh->nlmsg_seq = ++ys->seq;
420
421         return nlh;
422 }
423
424 struct nlmsghdr *
425 ynl_gemsg_start(struct ynl_sock *ys, __u32 id, __u16 flags,
426                 __u8 cmd, __u8 version)
427 {
428         struct genlmsghdr gehdr;
429         struct nlmsghdr *nlh;
430         void *data;
431
432         nlh = ynl_msg_start(ys, id, flags);
433
434         memset(&gehdr, 0, sizeof(gehdr));
435         gehdr.cmd = cmd;
436         gehdr.version = version;
437
438         data = mnl_nlmsg_put_extra_header(nlh, sizeof(gehdr));
439         memcpy(data, &gehdr, sizeof(gehdr));
440
441         return nlh;
442 }
443
444 void ynl_msg_start_req(struct ynl_sock *ys, __u32 id)
445 {
446         ynl_msg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK);
447 }
448
449 void ynl_msg_start_dump(struct ynl_sock *ys, __u32 id)
450 {
451         ynl_msg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP);
452 }
453
454 struct nlmsghdr *
455 ynl_gemsg_start_req(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version)
456 {
457         return ynl_gemsg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK, cmd, version);
458 }
459
460 struct nlmsghdr *
461 ynl_gemsg_start_dump(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version)
462 {
463         return ynl_gemsg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP,
464                                cmd, version);
465 }
466
467 int ynl_recv_ack(struct ynl_sock *ys, int ret)
468 {
469         struct ynl_parse_arg yarg = { .ys = ys, };
470
471         if (!ret) {
472                 yerr(ys, YNL_ERROR_EXPECT_ACK,
473                      "Expecting an ACK but nothing received");
474                 return -1;
475         }
476
477         ret = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);
478         if (ret < 0) {
479                 perr(ys, "Socket receive failed");
480                 return ret;
481         }
482         return mnl_cb_run(ys->rx_buf, ret, ys->seq, ys->portid,
483                           ynl_cb_null, &yarg);
484 }
485
486 int ynl_cb_null(const struct nlmsghdr *nlh, void *data)
487 {
488         struct ynl_parse_arg *yarg = data;
489
490         yerr(yarg->ys, YNL_ERROR_UNEXPECT_MSG,
491              "Received a message when none were expected");
492
493         return MNL_CB_ERROR;
494 }
495
496 /* Init/fini and genetlink boiler plate */
497 static int
498 ynl_get_family_info_mcast(struct ynl_sock *ys, const struct nlattr *mcasts)
499 {
500         const struct nlattr *entry, *attr;
501         unsigned int i;
502
503         mnl_attr_for_each_nested(attr, mcasts)
504                 ys->n_mcast_groups++;
505
506         if (!ys->n_mcast_groups)
507                 return 0;
508
509         ys->mcast_groups = calloc(ys->n_mcast_groups,
510                                   sizeof(*ys->mcast_groups));
511         if (!ys->mcast_groups)
512                 return MNL_CB_ERROR;
513
514         i = 0;
515         mnl_attr_for_each_nested(entry, mcasts) {
516                 mnl_attr_for_each_nested(attr, entry) {
517                         if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_ID)
518                                 ys->mcast_groups[i].id = mnl_attr_get_u32(attr);
519                         if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_NAME) {
520                                 strncpy(ys->mcast_groups[i].name,
521                                         mnl_attr_get_str(attr),
522                                         GENL_NAMSIZ - 1);
523                                 ys->mcast_groups[i].name[GENL_NAMSIZ - 1] = 0;
524                         }
525                 }
526                 i++;
527         }
528
529         return 0;
530 }
531
532 static int ynl_get_family_info_cb(const struct nlmsghdr *nlh, void *data)
533 {
534         struct ynl_parse_arg *yarg = data;
535         struct ynl_sock *ys = yarg->ys;
536         const struct nlattr *attr;
537         bool found_id = true;
538
539         mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr)) {
540                 if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GROUPS)
541                         if (ynl_get_family_info_mcast(ys, attr))
542                                 return MNL_CB_ERROR;
543
544                 if (mnl_attr_get_type(attr) != CTRL_ATTR_FAMILY_ID)
545                         continue;
546
547                 if (mnl_attr_get_payload_len(attr) != sizeof(__u16)) {
548                         yerr(ys, YNL_ERROR_ATTR_INVALID, "Invalid family ID");
549                         return MNL_CB_ERROR;
550                 }
551
552                 ys->family_id = mnl_attr_get_u16(attr);
553                 found_id = true;
554         }
555
556         if (!found_id) {
557                 yerr(ys, YNL_ERROR_ATTR_MISSING, "Family ID missing");
558                 return MNL_CB_ERROR;
559         }
560         return MNL_CB_OK;
561 }
562
563 static int ynl_sock_read_family(struct ynl_sock *ys, const char *family_name)
564 {
565         struct ynl_parse_arg yarg = { .ys = ys, };
566         struct nlmsghdr *nlh;
567         int err;
568
569         nlh = ynl_gemsg_start_req(ys, GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 1);
570         mnl_attr_put_strz(nlh, CTRL_ATTR_FAMILY_NAME, family_name);
571
572         err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);
573         if (err < 0) {
574                 perr(ys, "failed to request socket family info");
575                 return err;
576         }
577
578         err = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);
579         if (err <= 0) {
580                 perr(ys, "failed to receive the socket family info");
581                 return err;
582         }
583         err = mnl_cb_run2(ys->rx_buf, err, ys->seq, ys->portid,
584                           ynl_get_family_info_cb, &yarg,
585                           ynl_cb_array, ARRAY_SIZE(ynl_cb_array));
586         if (err < 0) {
587                 free(ys->mcast_groups);
588                 perr(ys, "failed to receive the socket family info - no such family?");
589                 return err;
590         }
591
592         err = ynl_recv_ack(ys, err);
593         if (err < 0) {
594                 free(ys->mcast_groups);
595                 return err;
596         }
597
598         return 0;
599 }
600
601 struct ynl_sock *
602 ynl_sock_create(const struct ynl_family *yf, struct ynl_error *yse)
603 {
604         struct ynl_sock *ys;
605         int one = 1;
606
607         ys = malloc(sizeof(*ys) + 2 * MNL_SOCKET_BUFFER_SIZE);
608         if (!ys)
609                 return NULL;
610         memset(ys, 0, sizeof(*ys));
611
612         ys->family = yf;
613         ys->tx_buf = &ys->raw_buf[0];
614         ys->rx_buf = &ys->raw_buf[MNL_SOCKET_BUFFER_SIZE];
615         ys->ntf_last_next = &ys->ntf_first;
616
617         ys->sock = mnl_socket_open(NETLINK_GENERIC);
618         if (!ys->sock) {
619                 __perr(yse, "failed to create a netlink socket");
620                 goto err_free_sock;
621         }
622
623         if (mnl_socket_setsockopt(ys->sock, NETLINK_CAP_ACK,
624                                   &one, sizeof(one))) {
625                 __perr(yse, "failed to enable netlink ACK");
626                 goto err_close_sock;
627         }
628         if (mnl_socket_setsockopt(ys->sock, NETLINK_EXT_ACK,
629                                   &one, sizeof(one))) {
630                 __perr(yse, "failed to enable netlink ext ACK");
631                 goto err_close_sock;
632         }
633
634         ys->seq = random();
635         ys->portid = mnl_socket_get_portid(ys->sock);
636
637         if (ynl_sock_read_family(ys, yf->name)) {
638                 if (yse)
639                         memcpy(yse, &ys->err, sizeof(*yse));
640                 goto err_close_sock;
641         }
642
643         return ys;
644
645 err_close_sock:
646         mnl_socket_close(ys->sock);
647 err_free_sock:
648         free(ys);
649         return NULL;
650 }
651
652 void ynl_sock_destroy(struct ynl_sock *ys)
653 {
654         struct ynl_ntf_base_type *ntf;
655
656         mnl_socket_close(ys->sock);
657         while ((ntf = ynl_ntf_dequeue(ys)))
658                 ynl_ntf_free(ntf);
659         free(ys->mcast_groups);
660         free(ys);
661 }
662
663 /* YNL multicast handling */
664
665 void ynl_ntf_free(struct ynl_ntf_base_type *ntf)
666 {
667         ntf->free(ntf);
668 }
669
670 int ynl_subscribe(struct ynl_sock *ys, const char *grp_name)
671 {
672         unsigned int i;
673         int err;
674
675         for (i = 0; i < ys->n_mcast_groups; i++)
676                 if (!strcmp(ys->mcast_groups[i].name, grp_name))
677                         break;
678         if (i == ys->n_mcast_groups) {
679                 yerr(ys, ENOENT, "Multicast group '%s' not found", grp_name);
680                 return -1;
681         }
682
683         err = mnl_socket_setsockopt(ys->sock, NETLINK_ADD_MEMBERSHIP,
684                                     &ys->mcast_groups[i].id,
685                                     sizeof(ys->mcast_groups[i].id));
686         if (err < 0) {
687                 perr(ys, "Subscribing to multicast group failed");
688                 return -1;
689         }
690
691         return 0;
692 }
693
694 int ynl_socket_get_fd(struct ynl_sock *ys)
695 {
696         return mnl_socket_get_fd(ys->sock);
697 }
698
699 struct ynl_ntf_base_type *ynl_ntf_dequeue(struct ynl_sock *ys)
700 {
701         struct ynl_ntf_base_type *ntf;
702
703         if (!ynl_has_ntf(ys))
704                 return NULL;
705
706         ntf = ys->ntf_first;
707         ys->ntf_first = ntf->next;
708         if (ys->ntf_last_next == &ntf->next)
709                 ys->ntf_last_next = &ys->ntf_first;
710
711         return ntf;
712 }
713
714 static int ynl_ntf_parse(struct ynl_sock *ys, const struct nlmsghdr *nlh)
715 {
716         struct ynl_parse_arg yarg = { .ys = ys, };
717         const struct ynl_ntf_info *info;
718         struct ynl_ntf_base_type *rsp;
719         struct genlmsghdr *gehdr;
720         int ret;
721
722         gehdr = mnl_nlmsg_get_payload(nlh);
723         if (gehdr->cmd >= ys->family->ntf_info_size)
724                 return MNL_CB_ERROR;
725         info = &ys->family->ntf_info[gehdr->cmd];
726         if (!info->cb)
727                 return MNL_CB_ERROR;
728
729         rsp = calloc(1, info->alloc_sz);
730         rsp->free = info->free;
731         yarg.data = rsp->data;
732         yarg.rsp_policy = info->policy;
733
734         ret = info->cb(nlh, &yarg);
735         if (ret <= MNL_CB_STOP)
736                 goto err_free;
737
738         rsp->family = nlh->nlmsg_type;
739         rsp->cmd = gehdr->cmd;
740
741         *ys->ntf_last_next = rsp;
742         ys->ntf_last_next = &rsp->next;
743
744         return MNL_CB_OK;
745
746 err_free:
747         info->free(rsp);
748         return MNL_CB_ERROR;
749 }
750
751 static int ynl_ntf_trampoline(const struct nlmsghdr *nlh, void *data)
752 {
753         struct ynl_parse_arg *yarg = data;
754
755         return ynl_ntf_parse(yarg->ys, nlh);
756 }
757
758 int ynl_ntf_check(struct ynl_sock *ys)
759 {
760         struct ynl_parse_arg yarg = { .ys = ys, };
761         ssize_t len;
762         int err;
763
764         do {
765                 /* libmnl doesn't let us pass flags to the recv to make
766                  * it non-blocking so we need to poll() or peek() :|
767                  */
768                 struct pollfd pfd = { };
769
770                 pfd.fd = mnl_socket_get_fd(ys->sock);
771                 pfd.events = POLLIN;
772                 err = poll(&pfd, 1, 1);
773                 if (err < 1)
774                         return err;
775
776                 len = mnl_socket_recvfrom(ys->sock, ys->rx_buf,
777                                           MNL_SOCKET_BUFFER_SIZE);
778                 if (len < 0)
779                         return len;
780
781                 err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,
782                                   ynl_ntf_trampoline, &yarg,
783                                   ynl_cb_array, NLMSG_MIN_TYPE);
784                 if (err < 0)
785                         return err;
786         } while (err > 0);
787
788         return 0;
789 }
790
791 /* YNL specific helpers used by the auto-generated code */
792
/* Sentinel terminating dump result lists (never dereferenced as an object). */
struct ynl_dump_list_type *YNL_LIST_END = (struct ynl_dump_list_type *)0xb4d123;
794
795 void ynl_error_unknown_notification(struct ynl_sock *ys, __u8 cmd)
796 {
797         yerr(ys, YNL_ERROR_UNKNOWN_NTF,
798              "Unknown notification message type '%d'", cmd);
799 }
800
801 int ynl_error_parse(struct ynl_parse_arg *yarg, const char *msg)
802 {
803         yerr(yarg->ys, YNL_ERROR_INV_RESP, "Error parsing response: %s", msg);
804         return MNL_CB_ERROR;
805 }
806
807 static int
808 ynl_check_alien(struct ynl_sock *ys, const struct nlmsghdr *nlh, __u32 rsp_cmd)
809 {
810         struct genlmsghdr *gehdr;
811
812         if (mnl_nlmsg_get_payload_len(nlh) < sizeof(*gehdr)) {
813                 yerr(ys, YNL_ERROR_INV_RESP,
814                      "Kernel responded with truncated message");
815                 return -1;
816         }
817
818         gehdr = mnl_nlmsg_get_payload(nlh);
819         if (gehdr->cmd != rsp_cmd)
820                 return ynl_ntf_parse(ys, nlh);
821
822         return 0;
823 }
824
825 static int ynl_req_trampoline(const struct nlmsghdr *nlh, void *data)
826 {
827         struct ynl_req_state *yrs = data;
828         int ret;
829
830         ret = ynl_check_alien(yrs->yarg.ys, nlh, yrs->rsp_cmd);
831         if (ret)
832                 return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK;
833
834         return yrs->cb(nlh, &yrs->yarg);
835 }
836
837 int ynl_exec(struct ynl_sock *ys, struct nlmsghdr *req_nlh,
838              struct ynl_req_state *yrs)
839 {
840         ssize_t len;
841         int err;
842
843         err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len);
844         if (err < 0)
845                 return err;
846
847         do {
848                 len = mnl_socket_recvfrom(ys->sock, ys->rx_buf,
849                                           MNL_SOCKET_BUFFER_SIZE);
850                 if (len < 0)
851                         return len;
852
853                 err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,
854                                   ynl_req_trampoline, yrs,
855                                   ynl_cb_array, NLMSG_MIN_TYPE);
856                 if (err < 0)
857                         return err;
858         } while (err > 0);
859
860         return 0;
861 }
862
863 static int ynl_dump_trampoline(const struct nlmsghdr *nlh, void *data)
864 {
865         struct ynl_dump_state *ds = data;
866         struct ynl_dump_list_type *obj;
867         struct ynl_parse_arg yarg = {};
868         int ret;
869
870         ret = ynl_check_alien(ds->ys, nlh, ds->rsp_cmd);
871         if (ret)
872                 return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK;
873
874         obj = calloc(1, ds->alloc_sz);
875         if (!obj)
876                 return MNL_CB_ERROR;
877
878         if (!ds->first)
879                 ds->first = obj;
880         if (ds->last)
881                 ds->last->next = obj;
882         ds->last = obj;
883
884         yarg.ys = ds->ys;
885         yarg.rsp_policy = ds->rsp_policy;
886         yarg.data = &obj->data;
887
888         return ds->cb(nlh, &yarg);
889 }
890
891 static void *ynl_dump_end(struct ynl_dump_state *ds)
892 {
893         if (!ds->first)
894                 return YNL_LIST_END;
895
896         ds->last->next = YNL_LIST_END;
897         return ds->first;
898 }
899
900 int ynl_exec_dump(struct ynl_sock *ys, struct nlmsghdr *req_nlh,
901                   struct ynl_dump_state *yds)
902 {
903         ssize_t len;
904         int err;
905
906         err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len);
907         if (err < 0)
908                 return err;
909
910         do {
911                 len = mnl_socket_recvfrom(ys->sock, ys->rx_buf,
912                                           MNL_SOCKET_BUFFER_SIZE);
913                 if (len < 0)
914                         goto err_close_list;
915
916                 err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,
917                                   ynl_dump_trampoline, yds,
918                                   ynl_cb_array, NLMSG_MIN_TYPE);
919                 if (err < 0)
920                         goto err_close_list;
921         } while (err > 0);
922
923         yds->first = ynl_dump_end(yds);
924         return 0;
925
926 err_close_list:
927         yds->first = ynl_dump_end(yds);
928         return -1;
929 }