Linux 6.9-rc5
[sfrench/cifs-2.6.git] / tools / testing / selftests / bpf / prog_tests / bpf_tcp_ca.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook */
3
4 #include <linux/err.h>
5 #include <test_progs.h>
6 #include "bpf_dctcp.skel.h"
7 #include "bpf_cubic.skel.h"
8
/* Local minimum helper.  Both arguments may be evaluated twice, so only
 * pass side-effect-free expressions.
 */
#define min(a, b) ((a) < (b) ? (a) : (b))

/* Payload the server streams to the client: 10 MiB. */
static const unsigned int total_bytes = 10u << 20;

/* Timeout applied as both SO_RCVTIMEO and SO_SNDTIMEO on the test sockets. */
static const struct timeval timeo_sec = {
	.tv_sec = 10,
};
static const size_t timeo_optlen = sizeof timeo_sec;

/* Raised by either side (server on error, client when done) so the other
 * side's transfer loop gives up.
 */
static int stop;
/* Not referenced directly by the code here; presumably consumed by the
 * CHECK() macro from test_progs.h.
 */
static int duration;
15
16 static int settimeo(int fd)
17 {
18         int err;
19
20         err = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeo_sec,
21                          timeo_optlen);
22         if (CHECK(err == -1, "setsockopt(fd, SO_RCVTIMEO)", "errno:%d\n",
23                   errno))
24                 return -1;
25
26         err = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeo_sec,
27                          timeo_optlen);
28         if (CHECK(err == -1, "setsockopt(fd, SO_SNDTIMEO)", "errno:%d\n",
29                   errno))
30                 return -1;
31
32         return 0;
33 }
34
/*
 * Select the TCP congestion-control algorithm named @tcp_ca on @fd.
 *
 * Returns 0 on success, -1 (after reporting via CHECK()) if setsockopt()
 * fails.
 */
static int settcpca(int fd, const char *tcp_ca)
{
	size_t name_len = strlen(tcp_ca);
	int ret;

	ret = setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, tcp_ca, name_len);

	return CHECK(ret == -1, "setsockopt(fd, TCP_CONGESTION)", "errno:%d\n",
		     errno) ? -1 : 0;
}
46
47 static void *server(void *arg)
48 {
49         int lfd = (int)(long)arg, err = 0, fd;
50         ssize_t nr_sent = 0, bytes = 0;
51         char batch[1500];
52
53         fd = accept(lfd, NULL, NULL);
54         while (fd == -1) {
55                 if (errno == EINTR)
56                         continue;
57                 err = -errno;
58                 goto done;
59         }
60
61         if (settimeo(fd)) {
62                 err = -errno;
63                 goto done;
64         }
65
66         while (bytes < total_bytes && !READ_ONCE(stop)) {
67                 nr_sent = send(fd, &batch,
68                                min(total_bytes - bytes, sizeof(batch)), 0);
69                 if (nr_sent == -1 && errno == EINTR)
70                         continue;
71                 if (nr_sent == -1) {
72                         err = -errno;
73                         break;
74                 }
75                 bytes += nr_sent;
76         }
77
78         CHECK(bytes != total_bytes, "send", "%zd != %u nr_sent:%zd errno:%d\n",
79               bytes, total_bytes, nr_sent, errno);
80
81 done:
82         if (fd != -1)
83                 close(fd);
84         if (err) {
85                 WRITE_ONCE(stop, 1);
86                 return ERR_PTR(err);
87         }
88         return NULL;
89 }
90
91 static void do_test(const char *tcp_ca)
92 {
93         struct sockaddr_in6 sa6 = {};
94         ssize_t nr_recv = 0, bytes = 0;
95         int lfd = -1, fd = -1;
96         pthread_t srv_thread;
97         socklen_t addrlen = sizeof(sa6);
98         void *thread_ret;
99         char batch[1500];
100         int err;
101
102         WRITE_ONCE(stop, 0);
103
104         lfd = socket(AF_INET6, SOCK_STREAM, 0);
105         if (CHECK(lfd == -1, "socket", "errno:%d\n", errno))
106                 return;
107         fd = socket(AF_INET6, SOCK_STREAM, 0);
108         if (CHECK(fd == -1, "socket", "errno:%d\n", errno)) {
109                 close(lfd);
110                 return;
111         }
112
113         if (settcpca(lfd, tcp_ca) || settcpca(fd, tcp_ca) ||
114             settimeo(lfd) || settimeo(fd))
115                 goto done;
116
117         /* bind, listen and start server thread to accept */
118         sa6.sin6_family = AF_INET6;
119         sa6.sin6_addr = in6addr_loopback;
120         err = bind(lfd, (struct sockaddr *)&sa6, addrlen);
121         if (CHECK(err == -1, "bind", "errno:%d\n", errno))
122                 goto done;
123         err = getsockname(lfd, (struct sockaddr *)&sa6, &addrlen);
124         if (CHECK(err == -1, "getsockname", "errno:%d\n", errno))
125                 goto done;
126         err = listen(lfd, 1);
127         if (CHECK(err == -1, "listen", "errno:%d\n", errno))
128                 goto done;
129         err = pthread_create(&srv_thread, NULL, server, (void *)(long)lfd);
130         if (CHECK(err != 0, "pthread_create", "err:%d\n", err))
131                 goto done;
132
133         /* connect to server */
134         err = connect(fd, (struct sockaddr *)&sa6, addrlen);
135         if (CHECK(err == -1, "connect", "errno:%d\n", errno))
136                 goto wait_thread;
137
138         /* recv total_bytes */
139         while (bytes < total_bytes && !READ_ONCE(stop)) {
140                 nr_recv = recv(fd, &batch,
141                                min(total_bytes - bytes, sizeof(batch)), 0);
142                 if (nr_recv == -1 && errno == EINTR)
143                         continue;
144                 if (nr_recv == -1)
145                         break;
146                 bytes += nr_recv;
147         }
148
149         CHECK(bytes != total_bytes, "recv", "%zd != %u nr_recv:%zd errno:%d\n",
150               bytes, total_bytes, nr_recv, errno);
151
152 wait_thread:
153         WRITE_ONCE(stop, 1);
154         pthread_join(srv_thread, &thread_ret);
155         CHECK(IS_ERR(thread_ret), "pthread_join", "thread_ret:%ld",
156               PTR_ERR(thread_ret));
157 done:
158         close(lfd);
159         close(fd);
160 }
161
162 static void test_cubic(void)
163 {
164         struct bpf_cubic *cubic_skel;
165         struct bpf_link *link;
166
167         cubic_skel = bpf_cubic__open_and_load();
168         if (CHECK(!cubic_skel, "bpf_cubic__open_and_load", "failed\n"))
169                 return;
170
171         link = bpf_map__attach_struct_ops(cubic_skel->maps.cubic);
172         if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n",
173                   PTR_ERR(link))) {
174                 bpf_cubic__destroy(cubic_skel);
175                 return;
176         }
177
178         do_test("bpf_cubic");
179
180         bpf_link__destroy(link);
181         bpf_cubic__destroy(cubic_skel);
182 }
183
184 static void test_dctcp(void)
185 {
186         struct bpf_dctcp *dctcp_skel;
187         struct bpf_link *link;
188
189         dctcp_skel = bpf_dctcp__open_and_load();
190         if (CHECK(!dctcp_skel, "bpf_dctcp__open_and_load", "failed\n"))
191                 return;
192
193         link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp);
194         if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n",
195                   PTR_ERR(link))) {
196                 bpf_dctcp__destroy(dctcp_skel);
197                 return;
198         }
199
200         do_test("bpf_dctcp");
201
202         bpf_link__destroy(link);
203         bpf_dctcp__destroy(dctcp_skel);
204 }
205
/*
 * Test entry point: run the "dctcp" subtest, then the "cubic" subtest,
 * each only when test__start_subtest() reports it should run.
 */
void test_bpf_tcp_ca(void)
{
	static const char *const subtest_names[] = { "dctcp", "cubic" };
	size_t idx;

	for (idx = 0; idx < sizeof(subtest_names) / sizeof(subtest_names[0]); idx++) {
		if (!test__start_subtest(subtest_names[idx]))
			continue;

		switch (idx) {
		case 0:
			test_dctcp();
			break;
		case 1:
			test_cubic();
			break;
		default:
			break;
		}
	}
}