Merge tag 'riscv/for-v5.3-rc4' of git://git.kernel.org/pub/scm/linux/kernel/git/riscv...
[sfrench/cifs-2.6.git] / tools / testing / selftests / net / tls.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <poll.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <unistd.h>
13
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
17
18 #include <sys/types.h>
19 #include <sys/sendfile.h>
20 #include <sys/socket.h>
21 #include <sys/stat.h>
22
23 #include "../kselftest_harness.h"
24
25 #define TLS_PAYLOAD_MAX_LEN 16384
26 #define SOL_TLS 282
27
28 #ifndef ENOTSUPP
29 #define ENOTSUPP 524
30 #endif
31
32 FIXTURE(tls_basic)
33 {
34         int fd, cfd;
35         bool notls;
36 };
37
38 FIXTURE_SETUP(tls_basic)
39 {
40         struct sockaddr_in addr;
41         socklen_t len;
42         int sfd, ret;
43
44         self->notls = false;
45         len = sizeof(addr);
46
47         addr.sin_family = AF_INET;
48         addr.sin_addr.s_addr = htonl(INADDR_ANY);
49         addr.sin_port = 0;
50
51         self->fd = socket(AF_INET, SOCK_STREAM, 0);
52         sfd = socket(AF_INET, SOCK_STREAM, 0);
53
54         ret = bind(sfd, &addr, sizeof(addr));
55         ASSERT_EQ(ret, 0);
56         ret = listen(sfd, 10);
57         ASSERT_EQ(ret, 0);
58
59         ret = getsockname(sfd, &addr, &len);
60         ASSERT_EQ(ret, 0);
61
62         ret = connect(self->fd, &addr, sizeof(addr));
63         ASSERT_EQ(ret, 0);
64
65         self->cfd = accept(sfd, &addr, &len);
66         ASSERT_GE(self->cfd, 0);
67
68         close(sfd);
69
70         ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
71         if (ret != 0) {
72                 ASSERT_EQ(errno, ENOENT);
73                 self->notls = true;
74                 printf("Failure setting TCP_ULP, testing without tls\n");
75                 return;
76         }
77
78         ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
79         ASSERT_EQ(ret, 0);
80 }
81
82 FIXTURE_TEARDOWN(tls_basic)
83 {
84         close(self->fd);
85         close(self->cfd);
86 }
87
88 /* Send some data through with ULP but no keys */
89 TEST_F(tls_basic, base_base)
90 {
91         char const *test_str = "test_read";
92         int send_len = 10;
93         char buf[10];
94
95         ASSERT_EQ(strlen(test_str) + 1, send_len);
96
97         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
98         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
99         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
100 };
101
102 FIXTURE(tls)
103 {
104         int fd, cfd;
105         bool notls;
106 };
107
108 FIXTURE_SETUP(tls)
109 {
110         struct tls12_crypto_info_aes_gcm_128 tls12;
111         struct sockaddr_in addr;
112         socklen_t len;
113         int sfd, ret;
114
115         self->notls = false;
116         len = sizeof(addr);
117
118         memset(&tls12, 0, sizeof(tls12));
119         tls12.info.version = TLS_1_3_VERSION;
120         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
121
122         addr.sin_family = AF_INET;
123         addr.sin_addr.s_addr = htonl(INADDR_ANY);
124         addr.sin_port = 0;
125
126         self->fd = socket(AF_INET, SOCK_STREAM, 0);
127         sfd = socket(AF_INET, SOCK_STREAM, 0);
128
129         ret = bind(sfd, &addr, sizeof(addr));
130         ASSERT_EQ(ret, 0);
131         ret = listen(sfd, 10);
132         ASSERT_EQ(ret, 0);
133
134         ret = getsockname(sfd, &addr, &len);
135         ASSERT_EQ(ret, 0);
136
137         ret = connect(self->fd, &addr, sizeof(addr));
138         ASSERT_EQ(ret, 0);
139
140         ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
141         if (ret != 0) {
142                 self->notls = true;
143                 printf("Failure setting TCP_ULP, testing without tls\n");
144         }
145
146         if (!self->notls) {
147                 ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12,
148                                  sizeof(tls12));
149                 ASSERT_EQ(ret, 0);
150         }
151
152         self->cfd = accept(sfd, &addr, &len);
153         ASSERT_GE(self->cfd, 0);
154
155         if (!self->notls) {
156                 ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls",
157                                  sizeof("tls"));
158                 ASSERT_EQ(ret, 0);
159
160                 ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12,
161                                  sizeof(tls12));
162                 ASSERT_EQ(ret, 0);
163         }
164
165         close(sfd);
166 }
167
168 FIXTURE_TEARDOWN(tls)
169 {
170         close(self->fd);
171         close(self->cfd);
172 }
173
174 TEST_F(tls, sendfile)
175 {
176         int filefd = open("/proc/self/exe", O_RDONLY);
177         struct stat st;
178
179         EXPECT_GE(filefd, 0);
180         fstat(filefd, &st);
181         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
182 }
183
184 TEST_F(tls, send_then_sendfile)
185 {
186         int filefd = open("/proc/self/exe", O_RDONLY);
187         char const *test_str = "test_send";
188         int to_send = strlen(test_str) + 1;
189         char recv_buf[10];
190         struct stat st;
191         char *buf;
192
193         EXPECT_GE(filefd, 0);
194         fstat(filefd, &st);
195         buf = (char *)malloc(st.st_size);
196
197         EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
198         EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
199         EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
200
201         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
202         EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
203 }
204
205 TEST_F(tls, recv_max)
206 {
207         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
208         char recv_mem[TLS_PAYLOAD_MAX_LEN];
209         char buf[TLS_PAYLOAD_MAX_LEN];
210
211         EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
212         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
213         EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
214 }
215
216 TEST_F(tls, recv_small)
217 {
218         char const *test_str = "test_read";
219         int send_len = 10;
220         char buf[10];
221
222         send_len = strlen(test_str) + 1;
223         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
224         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
225         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
226 }
227
228 TEST_F(tls, msg_more)
229 {
230         char const *test_str = "test_read";
231         int send_len = 10;
232         char buf[10 * 2];
233
234         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
235         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
236         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
237         EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
238                   send_len * 2);
239         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
240 }
241
242 TEST_F(tls, msg_more_unsent)
243 {
244         char const *test_str = "test_read";
245         int send_len = 10;
246         char buf[10];
247
248         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
249         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
250 }
251
252 TEST_F(tls, sendmsg_single)
253 {
254         struct msghdr msg;
255
256         char const *test_str = "test_sendmsg";
257         size_t send_len = 13;
258         struct iovec vec;
259         char buf[13];
260
261         vec.iov_base = (char *)test_str;
262         vec.iov_len = send_len;
263         memset(&msg, 0, sizeof(struct msghdr));
264         msg.msg_iov = &vec;
265         msg.msg_iovlen = 1;
266         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
267         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
268         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
269 }
270
271 TEST_F(tls, sendmsg_large)
272 {
273         void *mem = malloc(16384);
274         size_t send_len = 16384;
275         size_t sends = 128;
276         struct msghdr msg;
277         size_t recvs = 0;
278         size_t sent = 0;
279
280         memset(&msg, 0, sizeof(struct msghdr));
281         while (sent++ < sends) {
282                 struct iovec vec = { (void *)mem, send_len };
283
284                 msg.msg_iov = &vec;
285                 msg.msg_iovlen = 1;
286                 EXPECT_EQ(sendmsg(self->cfd, &msg, 0), send_len);
287         }
288
289         while (recvs++ < sends)
290                 EXPECT_NE(recv(self->fd, mem, send_len, 0), -1);
291
292         free(mem);
293 }
294
295 TEST_F(tls, sendmsg_multiple)
296 {
297         char const *test_str = "test_sendmsg_multiple";
298         struct iovec vec[5];
299         char *test_strs[5];
300         struct msghdr msg;
301         int total_len = 0;
302         int len_cmp = 0;
303         int iov_len = 5;
304         char *buf;
305         int i;
306
307         memset(&msg, 0, sizeof(struct msghdr));
308         for (i = 0; i < iov_len; i++) {
309                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
310                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
311                 vec[i].iov_base = (void *)test_strs[i];
312                 vec[i].iov_len = strlen(test_strs[i]) + 1;
313                 total_len += vec[i].iov_len;
314         }
315         msg.msg_iov = vec;
316         msg.msg_iovlen = iov_len;
317
318         EXPECT_EQ(sendmsg(self->cfd, &msg, 0), total_len);
319         buf = malloc(total_len);
320         EXPECT_NE(recv(self->fd, buf, total_len, 0), -1);
321         for (i = 0; i < iov_len; i++) {
322                 EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
323                                  strlen(test_strs[i])),
324                           0);
325                 len_cmp += strlen(buf + len_cmp) + 1;
326         }
327         for (i = 0; i < iov_len; i++)
328                 free(test_strs[i]);
329         free(buf);
330 }
331
332 TEST_F(tls, sendmsg_multiple_stress)
333 {
334         char const *test_str = "abcdefghijklmno";
335         struct iovec vec[1024];
336         char *test_strs[1024];
337         int iov_len = 1024;
338         int total_len = 0;
339         char buf[1 << 14];
340         struct msghdr msg;
341         int len_cmp = 0;
342         int i;
343
344         memset(&msg, 0, sizeof(struct msghdr));
345         for (i = 0; i < iov_len; i++) {
346                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
347                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
348                 vec[i].iov_base = (void *)test_strs[i];
349                 vec[i].iov_len = strlen(test_strs[i]) + 1;
350                 total_len += vec[i].iov_len;
351         }
352         msg.msg_iov = vec;
353         msg.msg_iovlen = iov_len;
354
355         EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
356         EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
357
358         for (i = 0; i < iov_len; i++)
359                 len_cmp += strlen(buf + len_cmp) + 1;
360
361         for (i = 0; i < iov_len; i++)
362                 free(test_strs[i]);
363 }
364
365 TEST_F(tls, splice_from_pipe)
366 {
367         int send_len = TLS_PAYLOAD_MAX_LEN;
368         char mem_send[TLS_PAYLOAD_MAX_LEN];
369         char mem_recv[TLS_PAYLOAD_MAX_LEN];
370         int p[2];
371
372         ASSERT_GE(pipe(p), 0);
373         EXPECT_GE(write(p[1], mem_send, send_len), 0);
374         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
375         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
376         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
377 }
378
379 TEST_F(tls, splice_from_pipe2)
380 {
381         int send_len = 16000;
382         char mem_send[16000];
383         char mem_recv[16000];
384         int p2[2];
385         int p[2];
386
387         ASSERT_GE(pipe(p), 0);
388         ASSERT_GE(pipe(p2), 0);
389         EXPECT_GE(write(p[1], mem_send, 8000), 0);
390         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, 8000, 0), 0);
391         EXPECT_GE(write(p2[1], mem_send + 8000, 8000), 0);
392         EXPECT_GE(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 0);
393         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
394         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
395 }
396
397 TEST_F(tls, send_and_splice)
398 {
399         int send_len = TLS_PAYLOAD_MAX_LEN;
400         char mem_send[TLS_PAYLOAD_MAX_LEN];
401         char mem_recv[TLS_PAYLOAD_MAX_LEN];
402         char const *test_str = "test_read";
403         int send_len2 = 10;
404         char buf[10];
405         int p[2];
406
407         ASSERT_GE(pipe(p), 0);
408         EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
409         EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
410         EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
411
412         EXPECT_GE(write(p[1], mem_send, send_len), send_len);
413         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
414
415         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
416         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
417 }
418
419 TEST_F(tls, splice_to_pipe)
420 {
421         int send_len = TLS_PAYLOAD_MAX_LEN;
422         char mem_send[TLS_PAYLOAD_MAX_LEN];
423         char mem_recv[TLS_PAYLOAD_MAX_LEN];
424         int p[2];
425
426         ASSERT_GE(pipe(p), 0);
427         EXPECT_GE(send(self->fd, mem_send, send_len, 0), 0);
428         EXPECT_GE(splice(self->cfd, NULL, p[1], NULL, send_len, 0), 0);
429         EXPECT_GE(read(p[0], mem_recv, send_len), 0);
430         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
431 }
432
433 TEST_F(tls, recvmsg_single)
434 {
435         char const *test_str = "test_recvmsg_single";
436         int send_len = strlen(test_str) + 1;
437         char buf[20];
438         struct msghdr hdr;
439         struct iovec vec;
440
441         memset(&hdr, 0, sizeof(hdr));
442         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
443         vec.iov_base = (char *)buf;
444         vec.iov_len = send_len;
445         hdr.msg_iovlen = 1;
446         hdr.msg_iov = &vec;
447         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
448         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
449 }
450
451 TEST_F(tls, recvmsg_single_max)
452 {
453         int send_len = TLS_PAYLOAD_MAX_LEN;
454         char send_mem[TLS_PAYLOAD_MAX_LEN];
455         char recv_mem[TLS_PAYLOAD_MAX_LEN];
456         struct iovec vec;
457         struct msghdr hdr;
458
459         EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
460         vec.iov_base = (char *)recv_mem;
461         vec.iov_len = TLS_PAYLOAD_MAX_LEN;
462
463         hdr.msg_iovlen = 1;
464         hdr.msg_iov = &vec;
465         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
466         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
467 }
468
469 TEST_F(tls, recvmsg_multiple)
470 {
471         unsigned int msg_iovlen = 1024;
472         unsigned int len_compared = 0;
473         struct iovec vec[1024];
474         char *iov_base[1024];
475         unsigned int iov_len = 16;
476         int send_len = 1 << 14;
477         char buf[1 << 14];
478         struct msghdr hdr;
479         int i;
480
481         EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
482         for (i = 0; i < msg_iovlen; i++) {
483                 iov_base[i] = (char *)malloc(iov_len);
484                 vec[i].iov_base = iov_base[i];
485                 vec[i].iov_len = iov_len;
486         }
487
488         hdr.msg_iovlen = msg_iovlen;
489         hdr.msg_iov = vec;
490         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
491         for (i = 0; i < msg_iovlen; i++)
492                 len_compared += iov_len;
493
494         for (i = 0; i < msg_iovlen; i++)
495                 free(iov_base[i]);
496 }
497
498 TEST_F(tls, single_send_multiple_recv)
499 {
500         unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
501         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
502         char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
503         char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
504
505         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
506         memset(recv_mem, 0, total_len);
507
508         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
509         EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
510         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
511 }
512
513 TEST_F(tls, multiple_send_single_recv)
514 {
515         unsigned int total_len = 2 * 10;
516         unsigned int send_len = 10;
517         char recv_mem[2 * 10];
518         char send_mem[10];
519
520         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
521         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
522         memset(recv_mem, 0, total_len);
523         EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
524
525         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
526         EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
527 }
528
529 TEST_F(tls, single_send_multiple_recv_non_align)
530 {
531         const unsigned int total_len = 15;
532         const unsigned int recv_len = 10;
533         char recv_mem[recv_len * 2];
534         char send_mem[total_len];
535
536         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
537         memset(recv_mem, 0, total_len);
538
539         EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
540         EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
541         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
542 }
543
544 TEST_F(tls, recv_partial)
545 {
546         char const *test_str = "test_read_partial";
547         char const *test_str_first = "test_read";
548         char const *test_str_second = "_partial";
549         int send_len = strlen(test_str) + 1;
550         char recv_mem[18];
551
552         memset(recv_mem, 0, sizeof(recv_mem));
553         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
554         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_first),
555                        MSG_WAITALL), -1);
556         EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
557         memset(recv_mem, 0, sizeof(recv_mem));
558         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_second),
559                        MSG_WAITALL), -1);
560         EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
561                   0);
562 }
563
564 TEST_F(tls, recv_nonblock)
565 {
566         char buf[4096];
567         bool err;
568
569         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
570         err = (errno == EAGAIN || errno == EWOULDBLOCK);
571         EXPECT_EQ(err, true);
572 }
573
574 TEST_F(tls, recv_peek)
575 {
576         char const *test_str = "test_read_peek";
577         int send_len = strlen(test_str) + 1;
578         char buf[15];
579
580         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
581         EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
582         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
583         memset(buf, 0, sizeof(buf));
584         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
585         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
586 }
587
588 TEST_F(tls, recv_peek_multiple)
589 {
590         char const *test_str = "test_read_peek";
591         int send_len = strlen(test_str) + 1;
592         unsigned int num_peeks = 100;
593         char buf[15];
594         int i;
595
596         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
597         for (i = 0; i < num_peeks; i++) {
598                 EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
599                 EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
600                 memset(buf, 0, sizeof(buf));
601         }
602         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
603         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
604 }
605
606 TEST_F(tls, recv_peek_multiple_records)
607 {
608         char const *test_str = "test_read_peek_mult_recs";
609         char const *test_str_first = "test_read_peek";
610         char const *test_str_second = "_mult_recs";
611         int len;
612         char buf[64];
613
614         len = strlen(test_str_first);
615         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
616
617         len = strlen(test_str_second) + 1;
618         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
619
620         len = strlen(test_str_first);
621         memset(buf, 0, len);
622         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
623
624         /* MSG_PEEK can only peek into the current record. */
625         len = strlen(test_str_first);
626         EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
627
628         len = strlen(test_str) + 1;
629         memset(buf, 0, len);
630         EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
631
632         /* Non-MSG_PEEK will advance strparser (and therefore record)
633          * however.
634          */
635         len = strlen(test_str) + 1;
636         EXPECT_EQ(memcmp(test_str, buf, len), 0);
637
638         /* MSG_MORE will hold current record open, so later MSG_PEEK
639          * will see everything.
640          */
641         len = strlen(test_str_first);
642         EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
643
644         len = strlen(test_str_second) + 1;
645         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
646
647         len = strlen(test_str) + 1;
648         memset(buf, 0, len);
649         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
650
651         len = strlen(test_str) + 1;
652         EXPECT_EQ(memcmp(test_str, buf, len), 0);
653 }
654
655 TEST_F(tls, recv_peek_large_buf_mult_recs)
656 {
657         char const *test_str = "test_read_peek_mult_recs";
658         char const *test_str_first = "test_read_peek";
659         char const *test_str_second = "_mult_recs";
660         int len;
661         char buf[64];
662
663         len = strlen(test_str_first);
664         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
665
666         len = strlen(test_str_second) + 1;
667         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
668
669         len = strlen(test_str) + 1;
670         memset(buf, 0, len);
671         EXPECT_NE((len = recv(self->cfd, buf, len,
672                               MSG_PEEK | MSG_WAITALL)), -1);
673         len = strlen(test_str) + 1;
674         EXPECT_EQ(memcmp(test_str, buf, len), 0);
675 }
676
677 TEST_F(tls, recv_lowat)
678 {
679         char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
680         char recv_mem[20];
681         int lowat = 8;
682
683         EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
684         EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
685
686         memset(recv_mem, 0, 20);
687         EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
688                              &lowat, sizeof(lowat)), 0);
689         EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
690         EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
691         EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
692
693         EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
694         EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
695 }
696
697 TEST_F(tls, bidir)
698 {
699         char const *test_str = "test_read";
700         int send_len = 10;
701         char buf[10];
702         int ret;
703
704         if (!self->notls) {
705                 struct tls12_crypto_info_aes_gcm_128 tls12;
706
707                 memset(&tls12, 0, sizeof(tls12));
708                 tls12.info.version = TLS_1_3_VERSION;
709                 tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
710
711                 ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
712                                  sizeof(tls12));
713                 ASSERT_EQ(ret, 0);
714
715                 ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
716                                  sizeof(tls12));
717                 ASSERT_EQ(ret, 0);
718         }
719
720         ASSERT_EQ(strlen(test_str) + 1, send_len);
721
722         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
723         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
724         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
725
726         memset(buf, 0, sizeof(buf));
727
728         EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
729         EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
730         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
731 };
732
733 TEST_F(tls, pollin)
734 {
735         char const *test_str = "test_poll";
736         struct pollfd fd = { 0, 0, 0 };
737         char buf[10];
738         int send_len = 10;
739
740         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
741         fd.fd = self->cfd;
742         fd.events = POLLIN;
743
744         EXPECT_EQ(poll(&fd, 1, 20), 1);
745         EXPECT_EQ(fd.revents & POLLIN, 1);
746         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
747         /* Test timing out */
748         EXPECT_EQ(poll(&fd, 1, 20), 0);
749 }
750
751 TEST_F(tls, poll_wait)
752 {
753         char const *test_str = "test_poll_wait";
754         int send_len = strlen(test_str) + 1;
755         struct pollfd fd = { 0, 0, 0 };
756         char recv_mem[15];
757
758         fd.fd = self->cfd;
759         fd.events = POLLIN;
760         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
761         /* Set timeout to inf. secs */
762         EXPECT_EQ(poll(&fd, 1, -1), 1);
763         EXPECT_EQ(fd.revents & POLLIN, 1);
764         EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
765 }
766
767 TEST_F(tls, poll_wait_split)
768 {
769         struct pollfd fd = { 0, 0, 0 };
770         char send_mem[20] = {};
771         char recv_mem[15];
772
773         fd.fd = self->cfd;
774         fd.events = POLLIN;
775         /* Send 20 bytes */
776         EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
777                   sizeof(send_mem));
778         /* Poll with inf. timeout */
779         EXPECT_EQ(poll(&fd, 1, -1), 1);
780         EXPECT_EQ(fd.revents & POLLIN, 1);
781         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
782                   sizeof(recv_mem));
783
784         /* Now the remaining 5 bytes of record data are in TLS ULP */
785         fd.fd = self->cfd;
786         fd.events = POLLIN;
787         EXPECT_EQ(poll(&fd, 1, -1), 1);
788         EXPECT_EQ(fd.revents & POLLIN, 1);
789         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
790                   sizeof(send_mem) - sizeof(recv_mem));
791 }
792
793 TEST_F(tls, blocking)
794 {
795         size_t data = 100000;
796         int res = fork();
797
798         EXPECT_NE(res, -1);
799
800         if (res) {
801                 /* parent */
802                 size_t left = data;
803                 char buf[16384];
804                 int status;
805                 int pid2;
806
807                 while (left) {
808                         int res = send(self->fd, buf,
809                                        left > 16384 ? 16384 : left, 0);
810
811                         EXPECT_GE(res, 0);
812                         left -= res;
813                 }
814
815                 pid2 = wait(&status);
816                 EXPECT_EQ(status, 0);
817                 EXPECT_EQ(res, pid2);
818         } else {
819                 /* child */
820                 size_t left = data;
821                 char buf[16384];
822
823                 while (left) {
824                         int res = recv(self->cfd, buf,
825                                        left > 16384 ? 16384 : left, 0);
826
827                         EXPECT_GE(res, 0);
828                         left -= res;
829                 }
830         }
831 }
832
833 TEST_F(tls, nonblocking)
834 {
835         size_t data = 100000;
836         int sendbuf = 100;
837         int flags;
838         int res;
839
840         flags = fcntl(self->fd, F_GETFL, 0);
841         fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
842         fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
843
844         /* Ensure nonblocking behavior by imposing a small send
845          * buffer.
846          */
847         EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
848                              &sendbuf, sizeof(sendbuf)), 0);
849
850         res = fork();
851         EXPECT_NE(res, -1);
852
853         if (res) {
854                 /* parent */
855                 bool eagain = false;
856                 size_t left = data;
857                 char buf[16384];
858                 int status;
859                 int pid2;
860
861                 while (left) {
862                         int res = send(self->fd, buf,
863                                        left > 16384 ? 16384 : left, 0);
864
865                         if (res == -1 && errno == EAGAIN) {
866                                 eagain = true;
867                                 usleep(10000);
868                                 continue;
869                         }
870                         EXPECT_GE(res, 0);
871                         left -= res;
872                 }
873
874                 EXPECT_TRUE(eagain);
875                 pid2 = wait(&status);
876
877                 EXPECT_EQ(status, 0);
878                 EXPECT_EQ(res, pid2);
879         } else {
880                 /* child */
881                 bool eagain = false;
882                 size_t left = data;
883                 char buf[16384];
884
885                 while (left) {
886                         int res = recv(self->cfd, buf,
887                                        left > 16384 ? 16384 : left, 0);
888
889                         if (res == -1 && errno == EAGAIN) {
890                                 eagain = true;
891                                 usleep(10000);
892                                 continue;
893                         }
894                         EXPECT_GE(res, 0);
895                         left -= res;
896                 }
897                 EXPECT_TRUE(eagain);
898         }
899 }
900
901 TEST_F(tls, control_msg)
902 {
903         if (self->notls)
904                 return;
905
906         char cbuf[CMSG_SPACE(sizeof(char))];
907         char const *test_str = "test_read";
908         int cmsg_len = sizeof(char);
909         char record_type = 100;
910         struct cmsghdr *cmsg;
911         struct msghdr msg;
912         int send_len = 10;
913         struct iovec vec;
914         char buf[10];
915
916         vec.iov_base = (char *)test_str;
917         vec.iov_len = 10;
918         memset(&msg, 0, sizeof(struct msghdr));
919         msg.msg_iov = &vec;
920         msg.msg_iovlen = 1;
921         msg.msg_control = cbuf;
922         msg.msg_controllen = sizeof(cbuf);
923         cmsg = CMSG_FIRSTHDR(&msg);
924         cmsg->cmsg_level = SOL_TLS;
925         /* test sending non-record types. */
926         cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
927         cmsg->cmsg_len = CMSG_LEN(cmsg_len);
928         *CMSG_DATA(cmsg) = record_type;
929         msg.msg_controllen = cmsg->cmsg_len;
930
931         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
932         /* Should fail because we didn't provide a control message */
933         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
934
935         vec.iov_base = buf;
936         EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL | MSG_PEEK), send_len);
937
938         cmsg = CMSG_FIRSTHDR(&msg);
939         EXPECT_NE(cmsg, NULL);
940         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
941         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
942         record_type = *((unsigned char *)CMSG_DATA(cmsg));
943         EXPECT_EQ(record_type, 100);
944         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
945
946         /* Recv the message again without MSG_PEEK */
947         record_type = 0;
948         memset(buf, 0, sizeof(buf));
949
950         EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL), send_len);
951         cmsg = CMSG_FIRSTHDR(&msg);
952         EXPECT_NE(cmsg, NULL);
953         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
954         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
955         record_type = *((unsigned char *)CMSG_DATA(cmsg));
956         EXPECT_EQ(record_type, 100);
957         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
958 }
959
960 TEST_F(tls, shutdown)
961 {
962         char const *test_str = "test_read";
963         int send_len = 10;
964         char buf[10];
965
966         ASSERT_EQ(strlen(test_str) + 1, send_len);
967
968         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
969         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
970         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
971
972         shutdown(self->fd, SHUT_RDWR);
973         shutdown(self->cfd, SHUT_RDWR);
974 }
975
976 TEST_F(tls, shutdown_unsent)
977 {
978         char const *test_str = "test_read";
979         int send_len = 10;
980
981         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
982
983         shutdown(self->fd, SHUT_RDWR);
984         shutdown(self->cfd, SHUT_RDWR);
985 }
986
987 TEST_F(tls, shutdown_reuse)
988 {
989         struct sockaddr_in addr;
990         int ret;
991
992         shutdown(self->fd, SHUT_RDWR);
993         shutdown(self->cfd, SHUT_RDWR);
994         close(self->cfd);
995
996         addr.sin_family = AF_INET;
997         addr.sin_addr.s_addr = htonl(INADDR_ANY);
998         addr.sin_port = 0;
999
1000         ret = bind(self->fd, &addr, sizeof(addr));
1001         EXPECT_EQ(ret, 0);
1002         ret = listen(self->fd, 10);
1003         EXPECT_EQ(ret, -1);
1004         EXPECT_EQ(errno, EINVAL);
1005
1006         ret = connect(self->fd, &addr, sizeof(addr));
1007         EXPECT_EQ(ret, -1);
1008         EXPECT_EQ(errno, EISCONN);
1009 }
1010
1011 TEST(non_established) {
1012         struct tls12_crypto_info_aes_gcm_256 tls12;
1013         struct sockaddr_in addr;
1014         int sfd, ret, fd;
1015         socklen_t len;
1016
1017         len = sizeof(addr);
1018
1019         memset(&tls12, 0, sizeof(tls12));
1020         tls12.info.version = TLS_1_2_VERSION;
1021         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1022
1023         addr.sin_family = AF_INET;
1024         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1025         addr.sin_port = 0;
1026
1027         fd = socket(AF_INET, SOCK_STREAM, 0);
1028         sfd = socket(AF_INET, SOCK_STREAM, 0);
1029
1030         ret = bind(sfd, &addr, sizeof(addr));
1031         ASSERT_EQ(ret, 0);
1032         ret = listen(sfd, 10);
1033         ASSERT_EQ(ret, 0);
1034
1035         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1036         EXPECT_EQ(ret, -1);
1037         /* TLS ULP not supported */
1038         if (errno == ENOENT)
1039                 return;
1040         EXPECT_EQ(errno, ENOTSUPP);
1041
1042         ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1043         EXPECT_EQ(ret, -1);
1044         EXPECT_EQ(errno, ENOTSUPP);
1045
1046         ret = getsockname(sfd, &addr, &len);
1047         ASSERT_EQ(ret, 0);
1048
1049         ret = connect(fd, &addr, sizeof(addr));
1050         ASSERT_EQ(ret, 0);
1051
1052         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1053         ASSERT_EQ(ret, 0);
1054
1055         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1056         EXPECT_EQ(ret, -1);
1057         EXPECT_EQ(errno, EEXIST);
1058
1059         close(fd);
1060         close(sfd);
1061 }
1062
1063 TEST(keysizes) {
1064         struct tls12_crypto_info_aes_gcm_256 tls12;
1065         struct sockaddr_in addr;
1066         int sfd, ret, fd, cfd;
1067         socklen_t len;
1068         bool notls;
1069
1070         notls = false;
1071         len = sizeof(addr);
1072
1073         memset(&tls12, 0, sizeof(tls12));
1074         tls12.info.version = TLS_1_2_VERSION;
1075         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1076
1077         addr.sin_family = AF_INET;
1078         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1079         addr.sin_port = 0;
1080
1081         fd = socket(AF_INET, SOCK_STREAM, 0);
1082         sfd = socket(AF_INET, SOCK_STREAM, 0);
1083
1084         ret = bind(sfd, &addr, sizeof(addr));
1085         ASSERT_EQ(ret, 0);
1086         ret = listen(sfd, 10);
1087         ASSERT_EQ(ret, 0);
1088
1089         ret = getsockname(sfd, &addr, &len);
1090         ASSERT_EQ(ret, 0);
1091
1092         ret = connect(fd, &addr, sizeof(addr));
1093         ASSERT_EQ(ret, 0);
1094
1095         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1096         if (ret != 0) {
1097                 notls = true;
1098                 printf("Failure setting TCP_ULP, testing without tls\n");
1099         }
1100
1101         if (!notls) {
1102                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
1103                                  sizeof(tls12));
1104                 EXPECT_EQ(ret, 0);
1105         }
1106
1107         cfd = accept(sfd, &addr, &len);
1108         ASSERT_GE(cfd, 0);
1109
1110         if (!notls) {
1111                 ret = setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls",
1112                                  sizeof("tls"));
1113                 EXPECT_EQ(ret, 0);
1114
1115                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
1116                                  sizeof(tls12));
1117                 EXPECT_EQ(ret, 0);
1118         }
1119
1120         close(sfd);
1121         close(fd);
1122         close(cfd);
1123 }
1124
1125 TEST(tls12) {
1126         int fd, cfd;
1127         bool notls;
1128
1129         struct tls12_crypto_info_aes_gcm_128 tls12;
1130         struct sockaddr_in addr;
1131         socklen_t len;
1132         int sfd, ret;
1133
1134         notls = false;
1135         len = sizeof(addr);
1136
1137         memset(&tls12, 0, sizeof(tls12));
1138         tls12.info.version = TLS_1_2_VERSION;
1139         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
1140
1141         addr.sin_family = AF_INET;
1142         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1143         addr.sin_port = 0;
1144
1145         fd = socket(AF_INET, SOCK_STREAM, 0);
1146         sfd = socket(AF_INET, SOCK_STREAM, 0);
1147
1148         ret = bind(sfd, &addr, sizeof(addr));
1149         ASSERT_EQ(ret, 0);
1150         ret = listen(sfd, 10);
1151         ASSERT_EQ(ret, 0);
1152
1153         ret = getsockname(sfd, &addr, &len);
1154         ASSERT_EQ(ret, 0);
1155
1156         ret = connect(fd, &addr, sizeof(addr));
1157         ASSERT_EQ(ret, 0);
1158
1159         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1160         if (ret != 0) {
1161                 notls = true;
1162                 printf("Failure setting TCP_ULP, testing without tls\n");
1163         }
1164
1165         if (!notls) {
1166                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
1167                                  sizeof(tls12));
1168                 ASSERT_EQ(ret, 0);
1169         }
1170
1171         cfd = accept(sfd, &addr, &len);
1172         ASSERT_GE(cfd, 0);
1173
1174         if (!notls) {
1175                 ret = setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls",
1176                                  sizeof("tls"));
1177                 ASSERT_EQ(ret, 0);
1178
1179                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
1180                                  sizeof(tls12));
1181                 ASSERT_EQ(ret, 0);
1182         }
1183
1184         close(sfd);
1185
1186         char const *test_str = "test_read";
1187         int send_len = 10;
1188         char buf[10];
1189
1190         send_len = strlen(test_str) + 1;
1191         EXPECT_EQ(send(fd, test_str, send_len, 0), send_len);
1192         EXPECT_NE(recv(cfd, buf, send_len, 0), -1);
1193         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1194
1195         close(fd);
1196         close(cfd);
1197 }
1198
1199 TEST_HARNESS_MAIN