Merge branch 'linux-4.21' of git://github.com/skeggsb/linux into drm-fixes
[sfrench/cifs-2.6.git] / tools / testing / selftests / bpf / test_sock_addr.c
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3
4 #define _GNU_SOURCE
5
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <unistd.h>
9
10 #include <arpa/inet.h>
11 #include <netinet/in.h>
12 #include <sys/types.h>
13 #include <sys/select.h>
14 #include <sys/socket.h>
15
16 #include <linux/filter.h>
17
18 #include <bpf/bpf.h>
19 #include <bpf/libbpf.h>
20
21 #include "cgroup_helpers.h"
22 #include "bpf_rlimit.h"
23 #include "bpf_util.h"
24
/*
 * ENOTSUPP may not be provided by userspace errno headers; define it with
 * the value 524 so tests can compare against it.
 */
#ifndef ENOTSUPP
# define ENOTSUPP 524
#endif

/* Cgroup path used by the test (resolved via cgroup_helpers — TODO confirm). */
#define CG_PATH "/foo"

/* Pre-compiled BPF object files, loaded through load_path(). */
#define CONNECT4_PROG_PATH      "./connect4_prog.o"
#define CONNECT6_PROG_PATH      "./connect6_prog.o"
#define SENDMSG4_PROG_PATH      "./sendmsg4_prog.o"
#define SENDMSG6_PROG_PATH      "./sendmsg6_prog.o"

/*
 * IPv4 endpoints. SERV4_IP/SERV4_PORT is what a test asks for;
 * SERV4_REWRITE_IP/SERV4_REWRITE_PORT is what the BPF programs rewrite it
 * to. SRC4_REWRITE_IP is the source address the connect4/sendmsg4 tests
 * expect after rewriting.
 * NOTE(review): SRC4_IP is not referenced in the visible part of the file.
 */
#define SERV4_IP                "192.168.1.254"
#define SERV4_REWRITE_IP        "127.0.0.1"
#define SRC4_IP                 "172.16.0.1"
#define SRC4_REWRITE_IP         "127.0.0.4"
#define SERV4_PORT              4040
#define SERV4_REWRITE_PORT      4444

/*
 * IPv6 endpoints, same roles as the IPv4 ones above. SERV6_V4MAPPED_IP is
 * an IPv4-mapped destination used by the "sendmsg6: IPv4-mapped IPv6" test.
 * NOTE(review): SRC6_IP is not referenced in the visible part of the file.
 */
#define SERV6_IP                "face:b00c:1234:5678::abcd"
#define SERV6_REWRITE_IP        "::1"
#define SERV6_V4MAPPED_IP       "::ffff:192.168.0.4"
#define SRC6_IP                 "::1"
#define SRC6_REWRITE_IP         "::6"
#define SERV6_PORT              6060
#define SERV6_REWRITE_PORT      6666

/*
 * Buffer size for textual addresses (presumably inet_ntop() output).
 * NOTE(review): smaller than INET6_ADDRSTRLEN (46), the portable bound for
 * inet_ntop(); confirm 40 suffices for every address printed.
 */
#define INET_NTOP_BUF   40
51
struct sock_addr_test;

/*
 * Builds and loads the BPF program for one test case.
 * Returns a program fd on success or a negative value on failure.
 */
typedef int (*load_fn)(const struct sock_addr_test *test);
/*
 * Same signature as getsockname(2)/getpeername(2); presumably used to query
 * a socket's local or peer address — confirm against callers.
 */
typedef int (*info_fn)(int, struct sockaddr *, socklen_t *);

/* Verifier log buffer passed to bpf_load_program_xattr() by load_insns(). */
char bpf_log_buf[BPF_LOG_BUF_SIZE];

/*
 * One entry of the test matrix.
 * NOTE: field order may be relied on by positional initializers; reorder
 * with care.
 */
struct sock_addr_test {
        const char *descr;
        /* BPF prog properties */
        load_fn loadfn;
        /* Attach type the program is loaded with (checked at load time). */
        enum bpf_attach_type expected_attach_type;
        /* Attach type used when attaching the program to the cgroup. */
        enum bpf_attach_type attach_type;
        /* Socket properties */
        int domain;
        int type;
        /* IP:port pairs for BPF prog to override */
        const char *requested_ip;
        unsigned short requested_port;
        const char *expected_ip;
        unsigned short expected_port;
        /* Expected source IP after rewriting; NULL in entries that set none. */
        const char *expected_src_ip;
        /* Expected test result */
        enum {
                /* Program load is expected to fail (load_insns() then
                 * suppresses the verifier dump). */
                LOAD_REJECT,
                /* Attaching with attach_type is expected to fail. */
                ATTACH_REJECT,
                /* The socket syscall is expected to fail with EPERM. */
                SYSCALL_EPERM,
                /* The socket syscall is expected to fail with ENOTSUPP. */
                SYSCALL_ENOTSUPP,
                SUCCESS,
        } expected_result;
};

/* Per-test program loaders; each one returns a prog fd or a negative value. */
static int bind4_prog_load(const struct sock_addr_test *test);
static int bind6_prog_load(const struct sock_addr_test *test);
static int connect4_prog_load(const struct sock_addr_test *test);
static int connect6_prog_load(const struct sock_addr_test *test);
static int sendmsg_deny_prog_load(const struct sock_addr_test *test);
static int sendmsg4_rw_asm_prog_load(const struct sock_addr_test *test);
static int sendmsg4_rw_c_prog_load(const struct sock_addr_test *test);
static int sendmsg6_rw_asm_prog_load(const struct sock_addr_test *test);
static int sendmsg6_rw_c_prog_load(const struct sock_addr_test *test);
static int sendmsg6_rw_v4mapped_prog_load(const struct sock_addr_test *test);
94
95 static struct sock_addr_test tests[] = {
96         /* bind */
97         {
98                 "bind4: load prog with wrong expected attach type",
99                 bind4_prog_load,
100                 BPF_CGROUP_INET6_BIND,
101                 BPF_CGROUP_INET4_BIND,
102                 AF_INET,
103                 SOCK_STREAM,
104                 NULL,
105                 0,
106                 NULL,
107                 0,
108                 NULL,
109                 LOAD_REJECT,
110         },
111         {
112                 "bind4: attach prog with wrong attach type",
113                 bind4_prog_load,
114                 BPF_CGROUP_INET4_BIND,
115                 BPF_CGROUP_INET6_BIND,
116                 AF_INET,
117                 SOCK_STREAM,
118                 NULL,
119                 0,
120                 NULL,
121                 0,
122                 NULL,
123                 ATTACH_REJECT,
124         },
125         {
126                 "bind4: rewrite IP & TCP port in",
127                 bind4_prog_load,
128                 BPF_CGROUP_INET4_BIND,
129                 BPF_CGROUP_INET4_BIND,
130                 AF_INET,
131                 SOCK_STREAM,
132                 SERV4_IP,
133                 SERV4_PORT,
134                 SERV4_REWRITE_IP,
135                 SERV4_REWRITE_PORT,
136                 NULL,
137                 SUCCESS,
138         },
139         {
140                 "bind4: rewrite IP & UDP port in",
141                 bind4_prog_load,
142                 BPF_CGROUP_INET4_BIND,
143                 BPF_CGROUP_INET4_BIND,
144                 AF_INET,
145                 SOCK_DGRAM,
146                 SERV4_IP,
147                 SERV4_PORT,
148                 SERV4_REWRITE_IP,
149                 SERV4_REWRITE_PORT,
150                 NULL,
151                 SUCCESS,
152         },
153         {
154                 "bind6: load prog with wrong expected attach type",
155                 bind6_prog_load,
156                 BPF_CGROUP_INET4_BIND,
157                 BPF_CGROUP_INET6_BIND,
158                 AF_INET6,
159                 SOCK_STREAM,
160                 NULL,
161                 0,
162                 NULL,
163                 0,
164                 NULL,
165                 LOAD_REJECT,
166         },
167         {
168                 "bind6: attach prog with wrong attach type",
169                 bind6_prog_load,
170                 BPF_CGROUP_INET6_BIND,
171                 BPF_CGROUP_INET4_BIND,
172                 AF_INET,
173                 SOCK_STREAM,
174                 NULL,
175                 0,
176                 NULL,
177                 0,
178                 NULL,
179                 ATTACH_REJECT,
180         },
181         {
182                 "bind6: rewrite IP & TCP port in",
183                 bind6_prog_load,
184                 BPF_CGROUP_INET6_BIND,
185                 BPF_CGROUP_INET6_BIND,
186                 AF_INET6,
187                 SOCK_STREAM,
188                 SERV6_IP,
189                 SERV6_PORT,
190                 SERV6_REWRITE_IP,
191                 SERV6_REWRITE_PORT,
192                 NULL,
193                 SUCCESS,
194         },
195         {
196                 "bind6: rewrite IP & UDP port in",
197                 bind6_prog_load,
198                 BPF_CGROUP_INET6_BIND,
199                 BPF_CGROUP_INET6_BIND,
200                 AF_INET6,
201                 SOCK_DGRAM,
202                 SERV6_IP,
203                 SERV6_PORT,
204                 SERV6_REWRITE_IP,
205                 SERV6_REWRITE_PORT,
206                 NULL,
207                 SUCCESS,
208         },
209
210         /* connect */
211         {
212                 "connect4: load prog with wrong expected attach type",
213                 connect4_prog_load,
214                 BPF_CGROUP_INET6_CONNECT,
215                 BPF_CGROUP_INET4_CONNECT,
216                 AF_INET,
217                 SOCK_STREAM,
218                 NULL,
219                 0,
220                 NULL,
221                 0,
222                 NULL,
223                 LOAD_REJECT,
224         },
225         {
226                 "connect4: attach prog with wrong attach type",
227                 connect4_prog_load,
228                 BPF_CGROUP_INET4_CONNECT,
229                 BPF_CGROUP_INET6_CONNECT,
230                 AF_INET,
231                 SOCK_STREAM,
232                 NULL,
233                 0,
234                 NULL,
235                 0,
236                 NULL,
237                 ATTACH_REJECT,
238         },
239         {
240                 "connect4: rewrite IP & TCP port",
241                 connect4_prog_load,
242                 BPF_CGROUP_INET4_CONNECT,
243                 BPF_CGROUP_INET4_CONNECT,
244                 AF_INET,
245                 SOCK_STREAM,
246                 SERV4_IP,
247                 SERV4_PORT,
248                 SERV4_REWRITE_IP,
249                 SERV4_REWRITE_PORT,
250                 SRC4_REWRITE_IP,
251                 SUCCESS,
252         },
253         {
254                 "connect4: rewrite IP & UDP port",
255                 connect4_prog_load,
256                 BPF_CGROUP_INET4_CONNECT,
257                 BPF_CGROUP_INET4_CONNECT,
258                 AF_INET,
259                 SOCK_DGRAM,
260                 SERV4_IP,
261                 SERV4_PORT,
262                 SERV4_REWRITE_IP,
263                 SERV4_REWRITE_PORT,
264                 SRC4_REWRITE_IP,
265                 SUCCESS,
266         },
267         {
268                 "connect6: load prog with wrong expected attach type",
269                 connect6_prog_load,
270                 BPF_CGROUP_INET4_CONNECT,
271                 BPF_CGROUP_INET6_CONNECT,
272                 AF_INET6,
273                 SOCK_STREAM,
274                 NULL,
275                 0,
276                 NULL,
277                 0,
278                 NULL,
279                 LOAD_REJECT,
280         },
281         {
282                 "connect6: attach prog with wrong attach type",
283                 connect6_prog_load,
284                 BPF_CGROUP_INET6_CONNECT,
285                 BPF_CGROUP_INET4_CONNECT,
286                 AF_INET,
287                 SOCK_STREAM,
288                 NULL,
289                 0,
290                 NULL,
291                 0,
292                 NULL,
293                 ATTACH_REJECT,
294         },
295         {
296                 "connect6: rewrite IP & TCP port",
297                 connect6_prog_load,
298                 BPF_CGROUP_INET6_CONNECT,
299                 BPF_CGROUP_INET6_CONNECT,
300                 AF_INET6,
301                 SOCK_STREAM,
302                 SERV6_IP,
303                 SERV6_PORT,
304                 SERV6_REWRITE_IP,
305                 SERV6_REWRITE_PORT,
306                 SRC6_REWRITE_IP,
307                 SUCCESS,
308         },
309         {
310                 "connect6: rewrite IP & UDP port",
311                 connect6_prog_load,
312                 BPF_CGROUP_INET6_CONNECT,
313                 BPF_CGROUP_INET6_CONNECT,
314                 AF_INET6,
315                 SOCK_DGRAM,
316                 SERV6_IP,
317                 SERV6_PORT,
318                 SERV6_REWRITE_IP,
319                 SERV6_REWRITE_PORT,
320                 SRC6_REWRITE_IP,
321                 SUCCESS,
322         },
323
324         /* sendmsg */
325         {
326                 "sendmsg4: load prog with wrong expected attach type",
327                 sendmsg4_rw_asm_prog_load,
328                 BPF_CGROUP_UDP6_SENDMSG,
329                 BPF_CGROUP_UDP4_SENDMSG,
330                 AF_INET,
331                 SOCK_DGRAM,
332                 NULL,
333                 0,
334                 NULL,
335                 0,
336                 NULL,
337                 LOAD_REJECT,
338         },
339         {
340                 "sendmsg4: attach prog with wrong attach type",
341                 sendmsg4_rw_asm_prog_load,
342                 BPF_CGROUP_UDP4_SENDMSG,
343                 BPF_CGROUP_UDP6_SENDMSG,
344                 AF_INET,
345                 SOCK_DGRAM,
346                 NULL,
347                 0,
348                 NULL,
349                 0,
350                 NULL,
351                 ATTACH_REJECT,
352         },
353         {
354                 "sendmsg4: rewrite IP & port (asm)",
355                 sendmsg4_rw_asm_prog_load,
356                 BPF_CGROUP_UDP4_SENDMSG,
357                 BPF_CGROUP_UDP4_SENDMSG,
358                 AF_INET,
359                 SOCK_DGRAM,
360                 SERV4_IP,
361                 SERV4_PORT,
362                 SERV4_REWRITE_IP,
363                 SERV4_REWRITE_PORT,
364                 SRC4_REWRITE_IP,
365                 SUCCESS,
366         },
367         {
368                 "sendmsg4: rewrite IP & port (C)",
369                 sendmsg4_rw_c_prog_load,
370                 BPF_CGROUP_UDP4_SENDMSG,
371                 BPF_CGROUP_UDP4_SENDMSG,
372                 AF_INET,
373                 SOCK_DGRAM,
374                 SERV4_IP,
375                 SERV4_PORT,
376                 SERV4_REWRITE_IP,
377                 SERV4_REWRITE_PORT,
378                 SRC4_REWRITE_IP,
379                 SUCCESS,
380         },
381         {
382                 "sendmsg4: deny call",
383                 sendmsg_deny_prog_load,
384                 BPF_CGROUP_UDP4_SENDMSG,
385                 BPF_CGROUP_UDP4_SENDMSG,
386                 AF_INET,
387                 SOCK_DGRAM,
388                 SERV4_IP,
389                 SERV4_PORT,
390                 SERV4_REWRITE_IP,
391                 SERV4_REWRITE_PORT,
392                 SRC4_REWRITE_IP,
393                 SYSCALL_EPERM,
394         },
395         {
396                 "sendmsg6: load prog with wrong expected attach type",
397                 sendmsg6_rw_asm_prog_load,
398                 BPF_CGROUP_UDP4_SENDMSG,
399                 BPF_CGROUP_UDP6_SENDMSG,
400                 AF_INET6,
401                 SOCK_DGRAM,
402                 NULL,
403                 0,
404                 NULL,
405                 0,
406                 NULL,
407                 LOAD_REJECT,
408         },
409         {
410                 "sendmsg6: attach prog with wrong attach type",
411                 sendmsg6_rw_asm_prog_load,
412                 BPF_CGROUP_UDP6_SENDMSG,
413                 BPF_CGROUP_UDP4_SENDMSG,
414                 AF_INET6,
415                 SOCK_DGRAM,
416                 NULL,
417                 0,
418                 NULL,
419                 0,
420                 NULL,
421                 ATTACH_REJECT,
422         },
423         {
424                 "sendmsg6: rewrite IP & port (asm)",
425                 sendmsg6_rw_asm_prog_load,
426                 BPF_CGROUP_UDP6_SENDMSG,
427                 BPF_CGROUP_UDP6_SENDMSG,
428                 AF_INET6,
429                 SOCK_DGRAM,
430                 SERV6_IP,
431                 SERV6_PORT,
432                 SERV6_REWRITE_IP,
433                 SERV6_REWRITE_PORT,
434                 SRC6_REWRITE_IP,
435                 SUCCESS,
436         },
437         {
438                 "sendmsg6: rewrite IP & port (C)",
439                 sendmsg6_rw_c_prog_load,
440                 BPF_CGROUP_UDP6_SENDMSG,
441                 BPF_CGROUP_UDP6_SENDMSG,
442                 AF_INET6,
443                 SOCK_DGRAM,
444                 SERV6_IP,
445                 SERV6_PORT,
446                 SERV6_REWRITE_IP,
447                 SERV6_REWRITE_PORT,
448                 SRC6_REWRITE_IP,
449                 SUCCESS,
450         },
451         {
452                 "sendmsg6: IPv4-mapped IPv6",
453                 sendmsg6_rw_v4mapped_prog_load,
454                 BPF_CGROUP_UDP6_SENDMSG,
455                 BPF_CGROUP_UDP6_SENDMSG,
456                 AF_INET6,
457                 SOCK_DGRAM,
458                 SERV6_IP,
459                 SERV6_PORT,
460                 SERV6_REWRITE_IP,
461                 SERV6_REWRITE_PORT,
462                 SRC6_REWRITE_IP,
463                 SYSCALL_ENOTSUPP,
464         },
465         {
466                 "sendmsg6: deny call",
467                 sendmsg_deny_prog_load,
468                 BPF_CGROUP_UDP6_SENDMSG,
469                 BPF_CGROUP_UDP6_SENDMSG,
470                 AF_INET6,
471                 SOCK_DGRAM,
472                 SERV6_IP,
473                 SERV6_PORT,
474                 SERV6_REWRITE_IP,
475                 SERV6_REWRITE_PORT,
476                 SRC6_REWRITE_IP,
477                 SYSCALL_EPERM,
478         },
479 };
480
/*
 * Build a socket address for @domain (AF_INET or AF_INET6) from the textual
 * address @ip and the host-byte-order @port, writing it into @addr, which is
 * @addr_len bytes long.
 *
 * For a supported family, the whole @addr buffer is zeroed first, even if it
 * then turns out to be too small.
 *
 * Returns 0 on success. Returns -1 if the family is unsupported, @addr_len
 * is smaller than the family's sockaddr, or @ip does not parse.
 */
static int mk_sockaddr(int domain, const char *ip, unsigned short port,
		       struct sockaddr *addr, socklen_t addr_len)
{
	struct sockaddr_in6 *sin6;
	struct sockaddr_in *sin;

	switch (domain) {
	case AF_INET:
		memset(addr, 0, addr_len);
		if (addr_len < sizeof(*sin))
			return -1;

		sin = (struct sockaddr_in *)addr;
		sin->sin_family = domain;
		sin->sin_port = htons(port);
		if (inet_pton(domain, ip, &sin->sin_addr) != 1) {
			log_err("Invalid IPv4: %s", ip);
			return -1;
		}
		return 0;

	case AF_INET6:
		memset(addr, 0, addr_len);
		if (addr_len < sizeof(*sin6))
			return -1;

		sin6 = (struct sockaddr_in6 *)addr;
		sin6->sin6_family = domain;
		sin6->sin6_port = htons(port);
		if (inet_pton(domain, ip, &sin6->sin6_addr) != 1) {
			log_err("Invalid IPv6: %s", ip);
			return -1;
		}
		return 0;

	default:
		log_err("Unsupported address family");
		return -1;
	}
}
518
519 static int load_insns(const struct sock_addr_test *test,
520                       const struct bpf_insn *insns, size_t insns_cnt)
521 {
522         struct bpf_load_program_attr load_attr;
523         int ret;
524
525         memset(&load_attr, 0, sizeof(struct bpf_load_program_attr));
526         load_attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR;
527         load_attr.expected_attach_type = test->expected_attach_type;
528         load_attr.insns = insns;
529         load_attr.insns_cnt = insns_cnt;
530         load_attr.license = "GPL";
531
532         ret = bpf_load_program_xattr(&load_attr, bpf_log_buf, BPF_LOG_BUF_SIZE);
533         if (ret < 0 && test->expected_result != LOAD_REJECT) {
534                 log_err(">>> Loading program error.\n"
535                         ">>> Verifier output:\n%s\n-------\n", bpf_log_buf);
536         }
537
538         return ret;
539 }
540
/* [1] These testing programs try to read different context fields, including
 * narrow loads of different sizes from user_ip4 and user_ip6, and write to
 * the fields that are allowed to be overridden.
 *
 * [2] BPF_LD_IMM64 & BPF_JMP_REG are used below whenever there is a need to
 * compare a register with an unsigned 32-bit integer. BPF_JMP_IMM can't be
 * used in such cases since it accepts only a _signed_ 32-bit integer as its
 * IMM argument. Also note that BPF_LD_IMM64 occupies 2 instructions, which
 * matters when counting jump offsets.
 */
551
/*
 * Build and load a bind4 program from raw BPF instructions.
 *
 * The program matches an AF_INET socket of type SOCK_DGRAM or SOCK_STREAM
 * whose user_ip4 equals SERV4_IP. The address is compared byte-wise,
 * half-wise and word-wise to exercise narrow context loads. On a match, it
 * rewrites user_ip4/user_port to SERV4_REWRITE_IP:SERV4_REWRITE_PORT. The
 * port is not checked. The program always returns 1.
 *
 * Returns the result of load_insns() (a prog fd, or negative on load
 * failure), or -1 if a constant address cannot be parsed.
 *
 * NOTE: jump offsets are hand-counted. Every failed check jumps to the final
 * "return 1". Recount all offsets when adding or removing instructions;
 * BPF_LD_IMM64 occupies two slots.
 */
static int bind4_prog_load(const struct sock_addr_test *test)
{
        /* SERV4_IP in network byte order, viewable as bytes, halves or a word
         * so each narrow load below has a matching immediate. */
        union {
                uint8_t u4_addr8[4];
                uint16_t u4_addr16[2];
                uint32_t u4_addr32;
        } ip4;
        struct sockaddr_in addr4_rw;

        if (inet_pton(AF_INET, SERV4_IP, (void *)&ip4) != 1) {
                log_err("Invalid IPv4: %s", SERV4_IP);
                return -1;
        }

        if (mk_sockaddr(AF_INET, SERV4_REWRITE_IP, SERV4_REWRITE_PORT,
                        (struct sockaddr *)&addr4_rw, sizeof(addr4_rw)) == -1)
                return -1;

        /* See [1]. */
        struct bpf_insn insns[] = {
                BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),

                /* if (sk.family == AF_INET && */
                BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, family)),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET, 24),

                /*     (sk.type == SOCK_DGRAM || sk.type == SOCK_STREAM) && */
                /* Not DGRAM -> go test STREAM; DGRAM -> skip the STREAM test. */
                BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, type)),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_DGRAM, 1),
                BPF_JMP_A(1),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_STREAM, 20),

                /*     1st_byte_of_user_ip4 == expected && */
                BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip4)),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[0], 18),

                /*     2nd_byte_of_user_ip4 == expected && */
                BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip4) + 1),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[1], 16),

                /*     3rd_byte_of_user_ip4 == expected && */
                BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip4) + 2),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[2], 14),

                /*     4th_byte_of_user_ip4 == expected && */
                BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip4) + 3),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[3], 12),

                /*     1st_half_of_user_ip4 == expected && */
                BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip4)),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr16[0], 10),

                /*     2nd_half_of_user_ip4 == expected && */
                BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip4) + 2),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr16[1], 8),

                /*     whole_user_ip4 == expected) { */
                BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip4)),
                BPF_LD_IMM64(BPF_REG_8, ip4.u4_addr32), /* See [2]. */
                BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 4),

                /*      user_ip4 = addr4_rw.sin_addr */
                BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_addr.s_addr),
                BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
                            offsetof(struct bpf_sock_addr, user_ip4)),

                /*      user_port = addr4_rw.sin_port */
                BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_port),
                BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
                            offsetof(struct bpf_sock_addr, user_port)),
                /* } */

                /* return 1 */
                BPF_MOV64_IMM(BPF_REG_0, 1),
                BPF_EXIT_INSN(),
        };

        return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
}
640
/*
 * Build and load a bind6 program from raw BPF instructions.
 *
 * The program matches an AF_INET6 socket whose user_ip6 agrees with SERV6_IP
 * in its 5th byte, 3rd 16-bit half and last 32-bit word. Mixed load sizes
 * are used on purpose to exercise narrow context loads. On a match, it
 * rewrites user_ip6/user_port to SERV6_REWRITE_IP:SERV6_REWRITE_PORT. The
 * socket type and port are not checked. The program always returns 1.
 *
 * Returns the result of load_insns() (a prog fd, or negative on load
 * failure), or -1 if a constant address cannot be parsed.
 *
 * NOTE: jump offsets are hand-counted. Every failed check jumps to the final
 * "return 1". STORE_IPV6_WORD expands to two instructions and BPF_LD_IMM64
 * occupies two slots.
 */
static int bind6_prog_load(const struct sock_addr_test *test)
{
        struct sockaddr_in6 addr6_rw;
        struct in6_addr ip6;

        if (inet_pton(AF_INET6, SERV6_IP, (void *)&ip6) != 1) {
                log_err("Invalid IPv6: %s", SERV6_IP);
                return -1;
        }

        if (mk_sockaddr(AF_INET6, SERV6_REWRITE_IP, SERV6_REWRITE_PORT,
                        (struct sockaddr *)&addr6_rw, sizeof(addr6_rw)) == -1)
                return -1;

        /* See [1]. */
        struct bpf_insn insns[] = {
                BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),

                /* if (sk.family == AF_INET6 && */
                BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, family)),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET6, 18),

                /*            5th_byte_of_user_ip6 == expected && */
                BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip6[1])),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr[4], 16),

                /*            3rd_half_of_user_ip6 == expected && */
                BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip6[1])),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr16[2], 14),

                /*            last_word_of_user_ip6 == expected) { */
                BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, user_ip6[3])),
                BPF_LD_IMM64(BPF_REG_8, ip6.s6_addr32[3]),  /* See [2]. */
                BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 10),


/* Two instructions: copy word N of addr6_rw's address into user_ip6[N]. */
#define STORE_IPV6_WORD(N)                                                     \
                BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_addr.s6_addr32[N]),     \
                BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,                       \
                            offsetof(struct bpf_sock_addr, user_ip6[N]))

                /*      user_ip6 = addr6_rw.sin6_addr */
                STORE_IPV6_WORD(0),
                STORE_IPV6_WORD(1),
                STORE_IPV6_WORD(2),
                STORE_IPV6_WORD(3),

                /*      user_port = addr6_rw.sin6_port */
                BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_port),
                BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
                            offsetof(struct bpf_sock_addr, user_port)),

                /* } */

                /* return 1 */
                BPF_MOV64_IMM(BPF_REG_0, 1),
                BPF_EXIT_INSN(),
        };

        return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
}
706
707 static int load_path(const struct sock_addr_test *test, const char *path)
708 {
709         struct bpf_prog_load_attr attr;
710         struct bpf_object *obj;
711         int prog_fd;
712
713         memset(&attr, 0, sizeof(struct bpf_prog_load_attr));
714         attr.file = path;
715         attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR;
716         attr.expected_attach_type = test->expected_attach_type;
717
718         if (bpf_prog_load_xattr(&attr, &obj, &prog_fd)) {
719                 if (test->expected_result != LOAD_REJECT)
720                         log_err(">>> Loading program (%s) error.\n", path);
721                 return -1;
722         }
723
724         return prog_fd;
725 }
726
727 static int connect4_prog_load(const struct sock_addr_test *test)
728 {
729         return load_path(test, CONNECT4_PROG_PATH);
730 }
731
732 static int connect6_prog_load(const struct sock_addr_test *test)
733 {
734         return load_path(test, CONNECT6_PROG_PATH);
735 }
736
737 static int sendmsg_deny_prog_load(const struct sock_addr_test *test)
738 {
739         struct bpf_insn insns[] = {
740                 /* return 0 */
741                 BPF_MOV64_IMM(BPF_REG_0, 0),
742                 BPF_EXIT_INSN(),
743         };
744         return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
745 }
746
/*
 * Build and load a sendmsg4 program from raw BPF instructions.
 *
 * For an AF_INET SOCK_DGRAM socket, the program rewrites:
 *   - msg_src_ip4 to SRC4_REWRITE_IP, and
 *   - user_ip4/user_port to SERV4_REWRITE_IP:SERV4_REWRITE_PORT.
 * The original destination is not inspected. The program always returns 1.
 *
 * Returns the result of load_insns(), or -1 if a constant address cannot be
 * parsed.
 *
 * NOTE: jump offsets are hand-counted. Both failed checks jump to the final
 * "return 1".
 */
static int sendmsg4_rw_asm_prog_load(const struct sock_addr_test *test)
{
        struct sockaddr_in dst4_rw_addr;
        struct in_addr src4_rw_ip;

        if (inet_pton(AF_INET, SRC4_REWRITE_IP, (void *)&src4_rw_ip) != 1) {
                log_err("Invalid IPv4: %s", SRC4_REWRITE_IP);
                return -1;
        }

        if (mk_sockaddr(AF_INET, SERV4_REWRITE_IP, SERV4_REWRITE_PORT,
                        (struct sockaddr *)&dst4_rw_addr,
                        sizeof(dst4_rw_addr)) == -1)
                return -1;

        struct bpf_insn insns[] = {
                BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),

                /* if (sk.family == AF_INET && */
                BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, family)),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET, 8),

                /*     sk.type == SOCK_DGRAM)  { */
                BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, type)),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_DGRAM, 6),

                /*      msg_src_ip4 = src4_rw_ip */
                BPF_MOV32_IMM(BPF_REG_7, src4_rw_ip.s_addr),
                BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
                            offsetof(struct bpf_sock_addr, msg_src_ip4)),

                /*      user_ip4 = dst4_rw_addr.sin_addr */
                BPF_MOV32_IMM(BPF_REG_7, dst4_rw_addr.sin_addr.s_addr),
                BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
                            offsetof(struct bpf_sock_addr, user_ip4)),

                /*      user_port = dst4_rw_addr.sin_port */
                BPF_MOV32_IMM(BPF_REG_7, dst4_rw_addr.sin_port),
                BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
                            offsetof(struct bpf_sock_addr, user_port)),
                /* } */

                /* return 1 */
                BPF_MOV64_IMM(BPF_REG_0, 1),
                BPF_EXIT_INSN(),
        };

        return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
}
798
799 static int sendmsg4_rw_c_prog_load(const struct sock_addr_test *test)
800 {
801         return load_path(test, SENDMSG4_PROG_PATH);
802 }
803
/*
 * Build and load a sendmsg6 program from raw BPF instructions.
 *
 * For an AF_INET6 socket (the type is not checked), the program rewrites:
 *   - msg_src_ip6 to SRC6_REWRITE_IP, and
 *   - user_ip6/user_port to @rw_dst_ip:SERV6_REWRITE_PORT.
 * The program always returns 1.
 *
 * @rw_dst_ip: textual IPv6 destination to rewrite to.
 *
 * Returns the result of load_insns(), or -1 if an address cannot be parsed.
 *
 * NOTE: the family-check jump offset (18) is hand-counted. Each STORE_IPV6
 * expands to 8 instructions.
 */
static int sendmsg6_rw_dst_asm_prog_load(const struct sock_addr_test *test,
                                         const char *rw_dst_ip)
{
        struct sockaddr_in6 dst6_rw_addr;
        struct in6_addr src6_rw_ip;

        if (inet_pton(AF_INET6, SRC6_REWRITE_IP, (void *)&src6_rw_ip) != 1) {
                log_err("Invalid IPv6: %s", SRC6_REWRITE_IP);
                return -1;
        }

        if (mk_sockaddr(AF_INET6, rw_dst_ip, SERV6_REWRITE_PORT,
                        (struct sockaddr *)&dst6_rw_addr,
                        sizeof(dst6_rw_addr)) == -1)
                return -1;

        struct bpf_insn insns[] = {
                BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),

                /* if (sk.family == AF_INET6) { */
                BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
                            offsetof(struct bpf_sock_addr, family)),
                BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET6, 18),

/* Two instructions: store 32-bit word SRC[N] into context field DST[N]. */
#define STORE_IPV6_WORD_N(DST, SRC, N)                                         \
                BPF_MOV32_IMM(BPF_REG_7, SRC[N]),                              \
                BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,                       \
                            offsetof(struct bpf_sock_addr, DST[N]))

/* Eight instructions: copy all four words of SRC into DST. */
#define STORE_IPV6(DST, SRC)                                                   \
                STORE_IPV6_WORD_N(DST, SRC, 0),                                \
                STORE_IPV6_WORD_N(DST, SRC, 1),                                \
                STORE_IPV6_WORD_N(DST, SRC, 2),                                \
                STORE_IPV6_WORD_N(DST, SRC, 3)

                /*      msg_src_ip6 = src6_rw_ip */
                STORE_IPV6(msg_src_ip6, src6_rw_ip.s6_addr32),
                /*      user_ip6 = dst6_rw_addr.sin6_addr */
                STORE_IPV6(user_ip6, dst6_rw_addr.sin6_addr.s6_addr32),

                /*      user_port = dst6_rw_addr.sin6_port */
                BPF_MOV32_IMM(BPF_REG_7, dst6_rw_addr.sin6_port),
                BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
                            offsetof(struct bpf_sock_addr, user_port)),

                /* } */

                /* return 1 */
                BPF_MOV64_IMM(BPF_REG_0, 1),
                BPF_EXIT_INSN(),
        };

        return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
}
856
857 static int sendmsg6_rw_asm_prog_load(const struct sock_addr_test *test)
858 {
859         return sendmsg6_rw_dst_asm_prog_load(test, SERV6_REWRITE_IP);
860 }
861
862 static int sendmsg6_rw_v4mapped_prog_load(const struct sock_addr_test *test)
863 {
864         return sendmsg6_rw_dst_asm_prog_load(test, SERV6_V4MAPPED_IP);
865 }
866
867 static int sendmsg6_rw_c_prog_load(const struct sock_addr_test *test)
868 {
869         return load_path(test, SENDMSG6_PROG_PATH);
870 }
871
/*
 * Compare two AF_INET or AF_INET6 socket addresses.
 *
 * The IP address is always compared; the port only when @cmp_port is
 * non-zero.
 *
 * Returns 0 if the addresses match, 1 if they differ, and -1 if the two
 * families differ or the family is neither AF_INET nor AF_INET6.
 */
static int cmp_addr(const struct sockaddr_storage *addr1,
		    const struct sockaddr_storage *addr2, int cmp_port)
{
	if (addr1->ss_family != addr2->ss_family)
		return -1;

	switch (addr1->ss_family) {
	case AF_INET: {
		const struct sockaddr_in *lhs = (const struct sockaddr_in *)addr1;
		const struct sockaddr_in *rhs = (const struct sockaddr_in *)addr2;
		int port_ok = !cmp_port || lhs->sin_port == rhs->sin_port;
		int ip_ok = lhs->sin_addr.s_addr == rhs->sin_addr.s_addr;

		return (port_ok && ip_ok) ? 0 : 1;
	}
	case AF_INET6: {
		const struct sockaddr_in6 *lhs = (const struct sockaddr_in6 *)addr1;
		const struct sockaddr_in6 *rhs = (const struct sockaddr_in6 *)addr2;
		int port_ok = !cmp_port || lhs->sin6_port == rhs->sin6_port;
		int ip_ok = memcmp(&lhs->sin6_addr, &rhs->sin6_addr,
				   sizeof(struct in6_addr)) == 0;

		return (port_ok && ip_ok) ? 0 : 1;
	}
	default:
		return -1;
	}
}
896
897 static int cmp_sock_addr(info_fn fn, int sock1,
898                          const struct sockaddr_storage *addr2, int cmp_port)
899 {
900         struct sockaddr_storage addr1;
901         socklen_t len1 = sizeof(addr1);
902
903         memset(&addr1, 0, len1);
904         if (fn(sock1, (struct sockaddr *)&addr1, (socklen_t *)&len1) != 0)
905                 return -1;
906
907         return cmp_addr(&addr1, addr2, cmp_port);
908 }
909
/*
 * Check @sock1's local address (getsockname(2)) against @addr2, IP only.
 * Returns 0 on match, non-zero on mismatch or error.
 */
static int cmp_local_ip(int sock1, const struct sockaddr_storage *addr2)
{
	const int cmp_port = 0;

	return cmp_sock_addr(getsockname, sock1, addr2, cmp_port);
}
914
/*
 * Check @sock1's local address (getsockname(2)) against @addr2, IP and port.
 * Returns 0 on match, non-zero on mismatch or error.
 */
static int cmp_local_addr(int sock1, const struct sockaddr_storage *addr2)
{
	const int cmp_port = 1;

	return cmp_sock_addr(getsockname, sock1, addr2, cmp_port);
}
919
/*
 * Check @sock1's peer address (getpeername(2)) against @addr2, IP and port.
 * Returns 0 on match, non-zero on mismatch or error.
 */
static int cmp_peer_addr(int sock1, const struct sockaddr_storage *addr2)
{
	const int cmp_port = 1;

	return cmp_sock_addr(getpeername, sock1, addr2, cmp_port);
}
924
/*
 * Create a socket of @type in @addr's family and bind it to @addr.
 * SOCK_STREAM sockets are also put into listening state (backlog 128).
 *
 * Returns the server fd, or -1 after logging if any step fails (the socket
 * is closed in that case).
 */
static int start_server(int type, const struct sockaddr_storage *addr,
			socklen_t addr_len)
{
	int srv = socket(addr->ss_family, type, 0);

	if (srv == -1) {
		log_err("Failed to create server socket");
		return -1;
	}

	if (bind(srv, (const struct sockaddr *)addr, addr_len) == -1) {
		log_err("Failed to bind server socket");
		close(srv);
		return -1;
	}

	if (type == SOCK_STREAM && listen(srv, 128) == -1) {
		log_err("Failed to listen on server socket");
		close(srv);
		return -1;
	}

	return srv;
}
955
/*
 * Create a client socket of @type in @addr's family and connect it to @addr.
 * Only AF_INET and AF_INET6 are accepted.
 *
 * Returns the connected fd (owned by the caller), or -1 after logging on
 * failure.  close() is only ever called on a descriptor that was actually
 * opened, so early failures do not issue a spurious close(-1).
 */
static int connect_to_server(int type, const struct sockaddr_storage *addr,
			     socklen_t addr_len)
{
	int domain = addr->ss_family;
	int fd;

	if (domain != AF_INET && domain != AF_INET6) {
		log_err("Unsupported address family");
		return -1;
	}

	fd = socket(domain, type, 0);
	if (fd == -1) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (connect(fd, (const struct sockaddr *)addr, addr_len) == -1) {
		log_err("Fail to connect to server");
		close(fd);
		return -1;
	}

	return fd;
}
987
988 int init_pktinfo(int domain, struct cmsghdr *cmsg)
989 {
990         struct in6_pktinfo *pktinfo6;
991         struct in_pktinfo *pktinfo4;
992
993         if (domain == AF_INET) {
994                 cmsg->cmsg_level = SOL_IP;
995                 cmsg->cmsg_type = IP_PKTINFO;
996                 cmsg->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
997                 pktinfo4 = (struct in_pktinfo *)CMSG_DATA(cmsg);
998                 memset(pktinfo4, 0, sizeof(struct in_pktinfo));
999                 if (inet_pton(domain, SRC4_IP,
1000                               (void *)&pktinfo4->ipi_spec_dst) != 1)
1001                         return -1;
1002         } else if (domain == AF_INET6) {
1003                 cmsg->cmsg_level = SOL_IPV6;
1004                 cmsg->cmsg_type = IPV6_PKTINFO;
1005                 cmsg->cmsg_len = CMSG_LEN(sizeof(struct in6_pktinfo));
1006                 pktinfo6 = (struct in6_pktinfo *)CMSG_DATA(cmsg);
1007                 memset(pktinfo6, 0, sizeof(struct in6_pktinfo));
1008                 if (inet_pton(domain, SRC6_IP,
1009                               (void *)&pktinfo6->ipi6_addr) != 1)
1010                         return -1;
1011         } else {
1012                 return -1;
1013         }
1014
1015         return 0;
1016 }
1017
1018 static int sendmsg_to_server(int type, const struct sockaddr_storage *addr,
1019                              socklen_t addr_len, int set_cmsg, int flags,
1020                              int *syscall_err)
1021 {
1022         union {
1023                 char buf[CMSG_SPACE(sizeof(struct in6_pktinfo))];
1024                 struct cmsghdr align;
1025         } control6;
1026         union {
1027                 char buf[CMSG_SPACE(sizeof(struct in_pktinfo))];
1028                 struct cmsghdr align;
1029         } control4;
1030         struct msghdr hdr;
1031         struct iovec iov;
1032         char data = 'a';
1033         int domain;
1034         int fd = -1;
1035
1036         domain = addr->ss_family;
1037
1038         if (domain != AF_INET && domain != AF_INET6) {
1039                 log_err("Unsupported address family");
1040                 goto err;
1041         }
1042
1043         fd = socket(domain, type, 0);
1044         if (fd == -1) {
1045                 log_err("Failed to create client socket");
1046                 goto err;
1047         }
1048
1049         memset(&iov, 0, sizeof(iov));
1050         iov.iov_base = &data;
1051         iov.iov_len = sizeof(data);
1052
1053         memset(&hdr, 0, sizeof(hdr));
1054         hdr.msg_name = (void *)addr;
1055         hdr.msg_namelen = addr_len;
1056         hdr.msg_iov = &iov;
1057         hdr.msg_iovlen = 1;
1058
1059         if (set_cmsg) {
1060                 if (domain == AF_INET) {
1061                         hdr.msg_control = &control4;
1062                         hdr.msg_controllen = sizeof(control4.buf);
1063                 } else if (domain == AF_INET6) {
1064                         hdr.msg_control = &control6;
1065                         hdr.msg_controllen = sizeof(control6.buf);
1066                 }
1067                 if (init_pktinfo(domain, CMSG_FIRSTHDR(&hdr))) {
1068                         log_err("Fail to init pktinfo");
1069                         goto err;
1070                 }
1071         }
1072
1073         if (sendmsg(fd, &hdr, flags) != sizeof(data)) {
1074                 log_err("Fail to send message to server");
1075                 *syscall_err = errno;
1076                 goto err;
1077         }
1078
1079         goto out;
1080 err:
1081         close(fd);
1082         fd = -1;
1083 out:
1084         return fd;
1085 }
1086
/*
 * Open a TCP connection to @addr with TCP Fast Open: sendmsg_to_server()
 * with MSG_FASTOPEN and no pktinfo control message.  Any sendmsg(2) errno
 * is discarded.
 *
 * Returns the client fd, or -1 on failure.
 */
static int fastconnect_to_server(const struct sockaddr_storage *addr,
				 socklen_t addr_len)
{
	const int no_cmsg = 0;
	int ignored_err;

	return sendmsg_to_server(SOCK_STREAM, addr, addr_len, no_cmsg,
				 MSG_FASTOPEN, &ignored_err);
}
1095
/*
 * Wait up to two seconds for @sockfd to become readable, then receive one
 * message (at most 64 bytes, discarded) and store the sender's address in
 * @src_addr.
 *
 * Returns the recvmsg(2) result, or -1 if select(2) fails or times out.
 */
static int recvmsg_from_client(int sockfd, struct sockaddr_storage *src_addr)
{
	char payload[64];
	struct iovec vec = {
		.iov_base = payload,
		.iov_len = sizeof(payload),
	};
	struct msghdr msg = {
		.msg_name = src_addr,
		.msg_namelen = sizeof(struct sockaddr_storage),
		.msg_iov = &vec,
		.msg_iovlen = 1,
	};
	struct timeval timeout = { .tv_sec = 2, .tv_usec = 0 };
	fd_set readable;
	int ready;

	FD_ZERO(&readable);
	FD_SET(sockfd, &readable);

	ready = select(sockfd + 1, &readable, NULL, NULL, &timeout);
	if (ready <= 0 || !FD_ISSET(sockfd, &readable))
		return -1;

	return recvmsg(sockfd, &msg, 0);
}
1126
1127 static int init_addrs(const struct sock_addr_test *test,
1128                       struct sockaddr_storage *requested_addr,
1129                       struct sockaddr_storage *expected_addr,
1130                       struct sockaddr_storage *expected_src_addr)
1131 {
1132         socklen_t addr_len = sizeof(struct sockaddr_storage);
1133
1134         if (mk_sockaddr(test->domain, test->expected_ip, test->expected_port,
1135                         (struct sockaddr *)expected_addr, addr_len) == -1)
1136                 goto err;
1137
1138         if (mk_sockaddr(test->domain, test->requested_ip, test->requested_port,
1139                         (struct sockaddr *)requested_addr, addr_len) == -1)
1140                 goto err;
1141
1142         if (test->expected_src_ip &&
1143             mk_sockaddr(test->domain, test->expected_src_ip, 0,
1144                         (struct sockaddr *)expected_src_addr, addr_len) == -1)
1145                 goto err;
1146
1147         return 0;
1148 err:
1149         return -1;
1150 }
1151
1152 static int run_bind_test_case(const struct sock_addr_test *test)
1153 {
1154         socklen_t addr_len = sizeof(struct sockaddr_storage);
1155         struct sockaddr_storage requested_addr;
1156         struct sockaddr_storage expected_addr;
1157         int clientfd = -1;
1158         int servfd = -1;
1159         int err = 0;
1160
1161         if (init_addrs(test, &requested_addr, &expected_addr, NULL))
1162                 goto err;
1163
1164         servfd = start_server(test->type, &requested_addr, addr_len);
1165         if (servfd == -1)
1166                 goto err;
1167
1168         if (cmp_local_addr(servfd, &expected_addr))
1169                 goto err;
1170
1171         /* Try to connect to server just in case */
1172         clientfd = connect_to_server(test->type, &expected_addr, addr_len);
1173         if (clientfd == -1)
1174                 goto err;
1175
1176         goto out;
1177 err:
1178         err = -1;
1179 out:
1180         close(clientfd);
1181         close(servfd);
1182         return err;
1183 }
1184
1185 static int run_connect_test_case(const struct sock_addr_test *test)
1186 {
1187         socklen_t addr_len = sizeof(struct sockaddr_storage);
1188         struct sockaddr_storage expected_src_addr;
1189         struct sockaddr_storage requested_addr;
1190         struct sockaddr_storage expected_addr;
1191         int clientfd = -1;
1192         int servfd = -1;
1193         int err = 0;
1194
1195         if (init_addrs(test, &requested_addr, &expected_addr,
1196                        &expected_src_addr))
1197                 goto err;
1198
1199         /* Prepare server to connect to */
1200         servfd = start_server(test->type, &expected_addr, addr_len);
1201         if (servfd == -1)
1202                 goto err;
1203
1204         clientfd = connect_to_server(test->type, &requested_addr, addr_len);
1205         if (clientfd == -1)
1206                 goto err;
1207
1208         /* Make sure src and dst addrs were overridden properly */
1209         if (cmp_peer_addr(clientfd, &expected_addr))
1210                 goto err;
1211
1212         if (cmp_local_ip(clientfd, &expected_src_addr))
1213                 goto err;
1214
1215         if (test->type == SOCK_STREAM) {
1216                 /* Test TCP Fast Open scenario */
1217                 clientfd = fastconnect_to_server(&requested_addr, addr_len);
1218                 if (clientfd == -1)
1219                         goto err;
1220
1221                 /* Make sure src and dst addrs were overridden properly */
1222                 if (cmp_peer_addr(clientfd, &expected_addr))
1223                         goto err;
1224
1225                 if (cmp_local_ip(clientfd, &expected_src_addr))
1226                         goto err;
1227         }
1228
1229         goto out;
1230 err:
1231         err = -1;
1232 out:
1233         close(clientfd);
1234         close(servfd);
1235         return err;
1236 }
1237
1238 static int run_sendmsg_test_case(const struct sock_addr_test *test)
1239 {
1240         socklen_t addr_len = sizeof(struct sockaddr_storage);
1241         struct sockaddr_storage expected_src_addr;
1242         struct sockaddr_storage requested_addr;
1243         struct sockaddr_storage expected_addr;
1244         struct sockaddr_storage real_src_addr;
1245         int clientfd = -1;
1246         int servfd = -1;
1247         int set_cmsg;
1248         int err = 0;
1249
1250         if (test->type != SOCK_DGRAM)
1251                 goto err;
1252
1253         if (init_addrs(test, &requested_addr, &expected_addr,
1254                        &expected_src_addr))
1255                 goto err;
1256
1257         /* Prepare server to sendmsg to */
1258         servfd = start_server(test->type, &expected_addr, addr_len);
1259         if (servfd == -1)
1260                 goto err;
1261
1262         for (set_cmsg = 0; set_cmsg <= 1; ++set_cmsg) {
1263                 if (clientfd >= 0)
1264                         close(clientfd);
1265
1266                 clientfd = sendmsg_to_server(test->type, &requested_addr,
1267                                              addr_len, set_cmsg, /*flags*/0,
1268                                              &err);
1269                 if (err)
1270                         goto out;
1271                 else if (clientfd == -1)
1272                         goto err;
1273
1274                 /* Try to receive message on server instead of using
1275                  * getpeername(2) on client socket, to check that client's
1276                  * destination address was rewritten properly, since
1277                  * getpeername(2) doesn't work with unconnected datagram
1278                  * sockets.
1279                  *
1280                  * Get source address from recvmsg(2) as well to make sure
1281                  * source was rewritten properly: getsockname(2) can't be used
1282                  * since socket is unconnected and source defined for one
1283                  * specific packet may differ from the one used by default and
1284                  * returned by getsockname(2).
1285                  */
1286                 if (recvmsg_from_client(servfd, &real_src_addr) == -1)
1287                         goto err;
1288
1289                 if (cmp_addr(&real_src_addr, &expected_src_addr, /*cmp_port*/0))
1290                         goto err;
1291         }
1292
1293         goto out;
1294 err:
1295         err = -1;
1296 out:
1297         close(clientfd);
1298         close(servfd);
1299         return err;
1300 }
1301
/*
 * Run one test case against the cgroup referred to by @cgfd.
 *
 * Loads the program via test->loadfn and attaches it to @cgfd with
 * BPF_F_ALLOW_OVERRIDE.  It then runs the scenario selected by
 * test->attach_type and compares the outcome with test->expected_result.
 * On every path the program is detached (best effort) and its fd closed,
 * and "[PASS]" or "[FAIL]" is printed.
 *
 * Returns 0 if the observed outcome matches test->expected_result, -1
 * otherwise.
 */
static int run_test_case(int cgfd, const struct sock_addr_test *test)
{
	int progfd = -1;
	int err = 0;

	printf("Test case: %s .. ", test->descr);

	progfd = test->loadfn(test);
	/* A failed load passes only if rejection was expected, and vice versa. */
	if (test->expected_result == LOAD_REJECT && progfd < 0)
		goto out;
	else if (test->expected_result == LOAD_REJECT || progfd < 0)
		goto err;

	err = bpf_prog_attach(progfd, cgfd, test->attach_type,
			      BPF_F_ALLOW_OVERRIDE);
	/* Same logic for attach: failure passes only under ATTACH_REJECT. */
	if (test->expected_result == ATTACH_REJECT && err) {
		err = 0; /* error was expected, reset it */
		goto out;
	} else if (test->expected_result == ATTACH_REJECT || err) {
		goto err;
	}

	/*
	 * err is reused for the scenario result: 0 on success, -1 on failure,
	 * or, for sendmsg scenarios, the errno of a failed sendmsg(2).
	 */
	switch (test->attach_type) {
	case BPF_CGROUP_INET4_BIND:
	case BPF_CGROUP_INET6_BIND:
		err = run_bind_test_case(test);
		break;
	case BPF_CGROUP_INET4_CONNECT:
	case BPF_CGROUP_INET6_CONNECT:
		err = run_connect_test_case(test);
		break;
	case BPF_CGROUP_UDP4_SENDMSG:
	case BPF_CGROUP_UDP6_SENDMSG:
		err = run_sendmsg_test_case(test);
		break;
	default:
		goto err;
	}

	/* A syscall error passes only when exactly that errno was expected. */
	if (test->expected_result == SYSCALL_EPERM && err == EPERM) {
		err = 0; /* error was expected, reset it */
		goto out;
	}

	if (test->expected_result == SYSCALL_ENOTSUPP && err == ENOTSUPP) {
		err = 0; /* error was expected, reset it */
		goto out;
	}

	if (err || test->expected_result != SUCCESS)
		goto err;

	goto out;
err:
	err = -1;
out:
	/* Detaching w/o checking return code: best effort attempt. */
	/*
	 * NOTE(review): a loadfn returning a negative value other than -1
	 * would still reach bpf_prog_detach() and close() here.
	 */
	if (progfd != -1)
		bpf_prog_detach(cgfd, test->attach_type);
	close(progfd);
	printf("[%s]\n", err ? "FAIL" : "PASS");
	return err;
}
1365
1366 static int run_tests(int cgfd)
1367 {
1368         int passes = 0;
1369         int fails = 0;
1370         int i;
1371
1372         for (i = 0; i < ARRAY_SIZE(tests); ++i) {
1373                 if (run_test_case(cgfd, &tests[i]))
1374                         ++fails;
1375                 else
1376                         ++passes;
1377         }
1378         printf("Summary: %d PASSED, %d FAILED\n", passes, fails);
1379         return fails ? -1 : 0;
1380 }
1381
1382 int main(int argc, char **argv)
1383 {
1384         int cgfd = -1;
1385         int err = 0;
1386
1387         if (argc < 2) {
1388                 fprintf(stderr,
1389                         "%s has to be run via %s.sh. Skip direct run.\n",
1390                         argv[0], argv[0]);
1391                 exit(err);
1392         }
1393
1394         if (setup_cgroup_environment())
1395                 goto err;
1396
1397         cgfd = create_and_get_cgroup(CG_PATH);
1398         if (!cgfd)
1399                 goto err;
1400
1401         if (join_cgroup(CG_PATH))
1402                 goto err;
1403
1404         if (run_tests(cgfd))
1405                 goto err;
1406
1407         goto out;
1408 err:
1409         err = -1;
1410 out:
1411         close(cgfd);
1412         cleanup_cgroup_environment();
1413         return err;
1414 }