selftests/tls: add a test for ULP but no keys
[sfrench/cifs-2.6.git] / tools / testing / selftests / net / tls.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <poll.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <unistd.h>
13
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
17
18 #include <sys/types.h>
19 #include <sys/sendfile.h>
20 #include <sys/socket.h>
21 #include <sys/stat.h>
22
23 #include "../kselftest_harness.h"
24
25 #define TLS_PAYLOAD_MAX_LEN 16384
26 #define SOL_TLS 282
27
28 #ifndef ENOTSUPP
29 #define ENOTSUPP 524
30 #endif
31
32 FIXTURE(tls_basic)
33 {
34         int fd, cfd;
35         bool notls;
36 };
37
38 FIXTURE_SETUP(tls_basic)
39 {
40         struct sockaddr_in addr;
41         socklen_t len;
42         int sfd, ret;
43
44         self->notls = false;
45         len = sizeof(addr);
46
47         addr.sin_family = AF_INET;
48         addr.sin_addr.s_addr = htonl(INADDR_ANY);
49         addr.sin_port = 0;
50
51         self->fd = socket(AF_INET, SOCK_STREAM, 0);
52         sfd = socket(AF_INET, SOCK_STREAM, 0);
53
54         ret = bind(sfd, &addr, sizeof(addr));
55         ASSERT_EQ(ret, 0);
56         ret = listen(sfd, 10);
57         ASSERT_EQ(ret, 0);
58
59         ret = getsockname(sfd, &addr, &len);
60         ASSERT_EQ(ret, 0);
61
62         ret = connect(self->fd, &addr, sizeof(addr));
63         ASSERT_EQ(ret, 0);
64
65         self->cfd = accept(sfd, &addr, &len);
66         ASSERT_GE(self->cfd, 0);
67
68         close(sfd);
69
70         ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
71         if (ret != 0) {
72                 ASSERT_EQ(errno, ENOTSUPP);
73                 self->notls = true;
74                 printf("Failure setting TCP_ULP, testing without tls\n");
75                 return;
76         }
77
78         ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
79         ASSERT_EQ(ret, 0);
80 }
81
82 FIXTURE_TEARDOWN(tls_basic)
83 {
84         close(self->fd);
85         close(self->cfd);
86 }
87
88 /* Send some data through with ULP but no keys */
89 TEST_F(tls_basic, base_base)
90 {
91         char const *test_str = "test_read";
92         int send_len = 10;
93         char buf[10];
94
95         ASSERT_EQ(strlen(test_str) + 1, send_len);
96
97         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
98         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
99         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
100 };
101
102 FIXTURE(tls)
103 {
104         int fd, cfd;
105         bool notls;
106 };
107
108 FIXTURE_SETUP(tls)
109 {
110         struct tls12_crypto_info_aes_gcm_128 tls12;
111         struct sockaddr_in addr;
112         socklen_t len;
113         int sfd, ret;
114
115         self->notls = false;
116         len = sizeof(addr);
117
118         memset(&tls12, 0, sizeof(tls12));
119         tls12.info.version = TLS_1_3_VERSION;
120         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
121
122         addr.sin_family = AF_INET;
123         addr.sin_addr.s_addr = htonl(INADDR_ANY);
124         addr.sin_port = 0;
125
126         self->fd = socket(AF_INET, SOCK_STREAM, 0);
127         sfd = socket(AF_INET, SOCK_STREAM, 0);
128
129         ret = bind(sfd, &addr, sizeof(addr));
130         ASSERT_EQ(ret, 0);
131         ret = listen(sfd, 10);
132         ASSERT_EQ(ret, 0);
133
134         ret = getsockname(sfd, &addr, &len);
135         ASSERT_EQ(ret, 0);
136
137         ret = connect(self->fd, &addr, sizeof(addr));
138         ASSERT_EQ(ret, 0);
139
140         ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
141         if (ret != 0) {
142                 self->notls = true;
143                 printf("Failure setting TCP_ULP, testing without tls\n");
144         }
145
146         if (!self->notls) {
147                 ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12,
148                                  sizeof(tls12));
149                 ASSERT_EQ(ret, 0);
150         }
151
152         self->cfd = accept(sfd, &addr, &len);
153         ASSERT_GE(self->cfd, 0);
154
155         if (!self->notls) {
156                 ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls",
157                                  sizeof("tls"));
158                 ASSERT_EQ(ret, 0);
159
160                 ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12,
161                                  sizeof(tls12));
162                 ASSERT_EQ(ret, 0);
163         }
164
165         close(sfd);
166 }
167
168 FIXTURE_TEARDOWN(tls)
169 {
170         close(self->fd);
171         close(self->cfd);
172 }
173
174 TEST_F(tls, sendfile)
175 {
176         int filefd = open("/proc/self/exe", O_RDONLY);
177         struct stat st;
178
179         EXPECT_GE(filefd, 0);
180         fstat(filefd, &st);
181         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
182 }
183
184 TEST_F(tls, send_then_sendfile)
185 {
186         int filefd = open("/proc/self/exe", O_RDONLY);
187         char const *test_str = "test_send";
188         int to_send = strlen(test_str) + 1;
189         char recv_buf[10];
190         struct stat st;
191         char *buf;
192
193         EXPECT_GE(filefd, 0);
194         fstat(filefd, &st);
195         buf = (char *)malloc(st.st_size);
196
197         EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
198         EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
199         EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
200
201         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
202         EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
203 }
204
205 TEST_F(tls, recv_max)
206 {
207         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
208         char recv_mem[TLS_PAYLOAD_MAX_LEN];
209         char buf[TLS_PAYLOAD_MAX_LEN];
210
211         EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
212         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
213         EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
214 }
215
216 TEST_F(tls, recv_small)
217 {
218         char const *test_str = "test_read";
219         int send_len = 10;
220         char buf[10];
221
222         send_len = strlen(test_str) + 1;
223         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
224         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
225         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
226 }
227
228 TEST_F(tls, msg_more)
229 {
230         char const *test_str = "test_read";
231         int send_len = 10;
232         char buf[10 * 2];
233
234         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
235         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
236         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
237         EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
238                   send_len * 2);
239         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
240 }
241
242 TEST_F(tls, sendmsg_single)
243 {
244         struct msghdr msg;
245
246         char const *test_str = "test_sendmsg";
247         size_t send_len = 13;
248         struct iovec vec;
249         char buf[13];
250
251         vec.iov_base = (char *)test_str;
252         vec.iov_len = send_len;
253         memset(&msg, 0, sizeof(struct msghdr));
254         msg.msg_iov = &vec;
255         msg.msg_iovlen = 1;
256         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
257         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
258         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
259 }
260
261 TEST_F(tls, sendmsg_large)
262 {
263         void *mem = malloc(16384);
264         size_t send_len = 16384;
265         size_t sends = 128;
266         struct msghdr msg;
267         size_t recvs = 0;
268         size_t sent = 0;
269
270         memset(&msg, 0, sizeof(struct msghdr));
271         while (sent++ < sends) {
272                 struct iovec vec = { (void *)mem, send_len };
273
274                 msg.msg_iov = &vec;
275                 msg.msg_iovlen = 1;
276                 EXPECT_EQ(sendmsg(self->cfd, &msg, 0), send_len);
277         }
278
279         while (recvs++ < sends)
280                 EXPECT_NE(recv(self->fd, mem, send_len, 0), -1);
281
282         free(mem);
283 }
284
285 TEST_F(tls, sendmsg_multiple)
286 {
287         char const *test_str = "test_sendmsg_multiple";
288         struct iovec vec[5];
289         char *test_strs[5];
290         struct msghdr msg;
291         int total_len = 0;
292         int len_cmp = 0;
293         int iov_len = 5;
294         char *buf;
295         int i;
296
297         memset(&msg, 0, sizeof(struct msghdr));
298         for (i = 0; i < iov_len; i++) {
299                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
300                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
301                 vec[i].iov_base = (void *)test_strs[i];
302                 vec[i].iov_len = strlen(test_strs[i]) + 1;
303                 total_len += vec[i].iov_len;
304         }
305         msg.msg_iov = vec;
306         msg.msg_iovlen = iov_len;
307
308         EXPECT_EQ(sendmsg(self->cfd, &msg, 0), total_len);
309         buf = malloc(total_len);
310         EXPECT_NE(recv(self->fd, buf, total_len, 0), -1);
311         for (i = 0; i < iov_len; i++) {
312                 EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
313                                  strlen(test_strs[i])),
314                           0);
315                 len_cmp += strlen(buf + len_cmp) + 1;
316         }
317         for (i = 0; i < iov_len; i++)
318                 free(test_strs[i]);
319         free(buf);
320 }
321
322 TEST_F(tls, sendmsg_multiple_stress)
323 {
324         char const *test_str = "abcdefghijklmno";
325         struct iovec vec[1024];
326         char *test_strs[1024];
327         int iov_len = 1024;
328         int total_len = 0;
329         char buf[1 << 14];
330         struct msghdr msg;
331         int len_cmp = 0;
332         int i;
333
334         memset(&msg, 0, sizeof(struct msghdr));
335         for (i = 0; i < iov_len; i++) {
336                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
337                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
338                 vec[i].iov_base = (void *)test_strs[i];
339                 vec[i].iov_len = strlen(test_strs[i]) + 1;
340                 total_len += vec[i].iov_len;
341         }
342         msg.msg_iov = vec;
343         msg.msg_iovlen = iov_len;
344
345         EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
346         EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
347
348         for (i = 0; i < iov_len; i++)
349                 len_cmp += strlen(buf + len_cmp) + 1;
350
351         for (i = 0; i < iov_len; i++)
352                 free(test_strs[i]);
353 }
354
355 TEST_F(tls, splice_from_pipe)
356 {
357         int send_len = TLS_PAYLOAD_MAX_LEN;
358         char mem_send[TLS_PAYLOAD_MAX_LEN];
359         char mem_recv[TLS_PAYLOAD_MAX_LEN];
360         int p[2];
361
362         ASSERT_GE(pipe(p), 0);
363         EXPECT_GE(write(p[1], mem_send, send_len), 0);
364         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
365         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
366         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
367 }
368
369 TEST_F(tls, splice_from_pipe2)
370 {
371         int send_len = 16000;
372         char mem_send[16000];
373         char mem_recv[16000];
374         int p2[2];
375         int p[2];
376
377         ASSERT_GE(pipe(p), 0);
378         ASSERT_GE(pipe(p2), 0);
379         EXPECT_GE(write(p[1], mem_send, 8000), 0);
380         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, 8000, 0), 0);
381         EXPECT_GE(write(p2[1], mem_send + 8000, 8000), 0);
382         EXPECT_GE(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 0);
383         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
384         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
385 }
386
387 TEST_F(tls, send_and_splice)
388 {
389         int send_len = TLS_PAYLOAD_MAX_LEN;
390         char mem_send[TLS_PAYLOAD_MAX_LEN];
391         char mem_recv[TLS_PAYLOAD_MAX_LEN];
392         char const *test_str = "test_read";
393         int send_len2 = 10;
394         char buf[10];
395         int p[2];
396
397         ASSERT_GE(pipe(p), 0);
398         EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
399         EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
400         EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
401
402         EXPECT_GE(write(p[1], mem_send, send_len), send_len);
403         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
404
405         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
406         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
407 }
408
409 TEST_F(tls, splice_to_pipe)
410 {
411         int send_len = TLS_PAYLOAD_MAX_LEN;
412         char mem_send[TLS_PAYLOAD_MAX_LEN];
413         char mem_recv[TLS_PAYLOAD_MAX_LEN];
414         int p[2];
415
416         ASSERT_GE(pipe(p), 0);
417         EXPECT_GE(send(self->fd, mem_send, send_len, 0), 0);
418         EXPECT_GE(splice(self->cfd, NULL, p[1], NULL, send_len, 0), 0);
419         EXPECT_GE(read(p[0], mem_recv, send_len), 0);
420         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
421 }
422
423 TEST_F(tls, recvmsg_single)
424 {
425         char const *test_str = "test_recvmsg_single";
426         int send_len = strlen(test_str) + 1;
427         char buf[20];
428         struct msghdr hdr;
429         struct iovec vec;
430
431         memset(&hdr, 0, sizeof(hdr));
432         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
433         vec.iov_base = (char *)buf;
434         vec.iov_len = send_len;
435         hdr.msg_iovlen = 1;
436         hdr.msg_iov = &vec;
437         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
438         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
439 }
440
441 TEST_F(tls, recvmsg_single_max)
442 {
443         int send_len = TLS_PAYLOAD_MAX_LEN;
444         char send_mem[TLS_PAYLOAD_MAX_LEN];
445         char recv_mem[TLS_PAYLOAD_MAX_LEN];
446         struct iovec vec;
447         struct msghdr hdr;
448
449         EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
450         vec.iov_base = (char *)recv_mem;
451         vec.iov_len = TLS_PAYLOAD_MAX_LEN;
452
453         hdr.msg_iovlen = 1;
454         hdr.msg_iov = &vec;
455         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
456         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
457 }
458
459 TEST_F(tls, recvmsg_multiple)
460 {
461         unsigned int msg_iovlen = 1024;
462         unsigned int len_compared = 0;
463         struct iovec vec[1024];
464         char *iov_base[1024];
465         unsigned int iov_len = 16;
466         int send_len = 1 << 14;
467         char buf[1 << 14];
468         struct msghdr hdr;
469         int i;
470
471         EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
472         for (i = 0; i < msg_iovlen; i++) {
473                 iov_base[i] = (char *)malloc(iov_len);
474                 vec[i].iov_base = iov_base[i];
475                 vec[i].iov_len = iov_len;
476         }
477
478         hdr.msg_iovlen = msg_iovlen;
479         hdr.msg_iov = vec;
480         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
481         for (i = 0; i < msg_iovlen; i++)
482                 len_compared += iov_len;
483
484         for (i = 0; i < msg_iovlen; i++)
485                 free(iov_base[i]);
486 }
487
488 TEST_F(tls, single_send_multiple_recv)
489 {
490         unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
491         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
492         char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
493         char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
494
495         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
496         memset(recv_mem, 0, total_len);
497
498         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
499         EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
500         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
501 }
502
503 TEST_F(tls, multiple_send_single_recv)
504 {
505         unsigned int total_len = 2 * 10;
506         unsigned int send_len = 10;
507         char recv_mem[2 * 10];
508         char send_mem[10];
509
510         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
511         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
512         memset(recv_mem, 0, total_len);
513         EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
514
515         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
516         EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
517 }
518
519 TEST_F(tls, single_send_multiple_recv_non_align)
520 {
521         const unsigned int total_len = 15;
522         const unsigned int recv_len = 10;
523         char recv_mem[recv_len * 2];
524         char send_mem[total_len];
525
526         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
527         memset(recv_mem, 0, total_len);
528
529         EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
530         EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
531         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
532 }
533
534 TEST_F(tls, recv_partial)
535 {
536         char const *test_str = "test_read_partial";
537         char const *test_str_first = "test_read";
538         char const *test_str_second = "_partial";
539         int send_len = strlen(test_str) + 1;
540         char recv_mem[18];
541
542         memset(recv_mem, 0, sizeof(recv_mem));
543         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
544         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_first),
545                        MSG_WAITALL), -1);
546         EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
547         memset(recv_mem, 0, sizeof(recv_mem));
548         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_second),
549                        MSG_WAITALL), -1);
550         EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
551                   0);
552 }
553
554 TEST_F(tls, recv_nonblock)
555 {
556         char buf[4096];
557         bool err;
558
559         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
560         err = (errno == EAGAIN || errno == EWOULDBLOCK);
561         EXPECT_EQ(err, true);
562 }
563
564 TEST_F(tls, recv_peek)
565 {
566         char const *test_str = "test_read_peek";
567         int send_len = strlen(test_str) + 1;
568         char buf[15];
569
570         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
571         EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
572         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
573         memset(buf, 0, sizeof(buf));
574         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
575         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
576 }
577
578 TEST_F(tls, recv_peek_multiple)
579 {
580         char const *test_str = "test_read_peek";
581         int send_len = strlen(test_str) + 1;
582         unsigned int num_peeks = 100;
583         char buf[15];
584         int i;
585
586         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
587         for (i = 0; i < num_peeks; i++) {
588                 EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
589                 EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
590                 memset(buf, 0, sizeof(buf));
591         }
592         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
593         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
594 }
595
596 TEST_F(tls, recv_peek_multiple_records)
597 {
598         char const *test_str = "test_read_peek_mult_recs";
599         char const *test_str_first = "test_read_peek";
600         char const *test_str_second = "_mult_recs";
601         int len;
602         char buf[64];
603
604         len = strlen(test_str_first);
605         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
606
607         len = strlen(test_str_second) + 1;
608         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
609
610         len = strlen(test_str_first);
611         memset(buf, 0, len);
612         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
613
614         /* MSG_PEEK can only peek into the current record. */
615         len = strlen(test_str_first);
616         EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
617
618         len = strlen(test_str) + 1;
619         memset(buf, 0, len);
620         EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
621
622         /* Non-MSG_PEEK will advance strparser (and therefore record)
623          * however.
624          */
625         len = strlen(test_str) + 1;
626         EXPECT_EQ(memcmp(test_str, buf, len), 0);
627
628         /* MSG_MORE will hold current record open, so later MSG_PEEK
629          * will see everything.
630          */
631         len = strlen(test_str_first);
632         EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
633
634         len = strlen(test_str_second) + 1;
635         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
636
637         len = strlen(test_str) + 1;
638         memset(buf, 0, len);
639         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
640
641         len = strlen(test_str) + 1;
642         EXPECT_EQ(memcmp(test_str, buf, len), 0);
643 }
644
645 TEST_F(tls, recv_peek_large_buf_mult_recs)
646 {
647         char const *test_str = "test_read_peek_mult_recs";
648         char const *test_str_first = "test_read_peek";
649         char const *test_str_second = "_mult_recs";
650         int len;
651         char buf[64];
652
653         len = strlen(test_str_first);
654         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
655
656         len = strlen(test_str_second) + 1;
657         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
658
659         len = strlen(test_str) + 1;
660         memset(buf, 0, len);
661         EXPECT_NE((len = recv(self->cfd, buf, len,
662                               MSG_PEEK | MSG_WAITALL)), -1);
663         len = strlen(test_str) + 1;
664         EXPECT_EQ(memcmp(test_str, buf, len), 0);
665 }
666
667 TEST_F(tls, recv_lowat)
668 {
669         char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
670         char recv_mem[20];
671         int lowat = 8;
672
673         EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
674         EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
675
676         memset(recv_mem, 0, 20);
677         EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
678                              &lowat, sizeof(lowat)), 0);
679         EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
680         EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
681         EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
682
683         EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
684         EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
685 }
686
687 TEST_F(tls, pollin)
688 {
689         char const *test_str = "test_poll";
690         struct pollfd fd = { 0, 0, 0 };
691         char buf[10];
692         int send_len = 10;
693
694         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
695         fd.fd = self->cfd;
696         fd.events = POLLIN;
697
698         EXPECT_EQ(poll(&fd, 1, 20), 1);
699         EXPECT_EQ(fd.revents & POLLIN, 1);
700         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
701         /* Test timing out */
702         EXPECT_EQ(poll(&fd, 1, 20), 0);
703 }
704
705 TEST_F(tls, poll_wait)
706 {
707         char const *test_str = "test_poll_wait";
708         int send_len = strlen(test_str) + 1;
709         struct pollfd fd = { 0, 0, 0 };
710         char recv_mem[15];
711
712         fd.fd = self->cfd;
713         fd.events = POLLIN;
714         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
715         /* Set timeout to inf. secs */
716         EXPECT_EQ(poll(&fd, 1, -1), 1);
717         EXPECT_EQ(fd.revents & POLLIN, 1);
718         EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
719 }
720
721 TEST_F(tls, poll_wait_split)
722 {
723         struct pollfd fd = { 0, 0, 0 };
724         char send_mem[20] = {};
725         char recv_mem[15];
726
727         fd.fd = self->cfd;
728         fd.events = POLLIN;
729         /* Send 20 bytes */
730         EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
731                   sizeof(send_mem));
732         /* Poll with inf. timeout */
733         EXPECT_EQ(poll(&fd, 1, -1), 1);
734         EXPECT_EQ(fd.revents & POLLIN, 1);
735         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
736                   sizeof(recv_mem));
737
738         /* Now the remaining 5 bytes of record data are in TLS ULP */
739         fd.fd = self->cfd;
740         fd.events = POLLIN;
741         EXPECT_EQ(poll(&fd, 1, -1), 1);
742         EXPECT_EQ(fd.revents & POLLIN, 1);
743         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
744                   sizeof(send_mem) - sizeof(recv_mem));
745 }
746
747 TEST_F(tls, blocking)
748 {
749         size_t data = 100000;
750         int res = fork();
751
752         EXPECT_NE(res, -1);
753
754         if (res) {
755                 /* parent */
756                 size_t left = data;
757                 char buf[16384];
758                 int status;
759                 int pid2;
760
761                 while (left) {
762                         int res = send(self->fd, buf,
763                                        left > 16384 ? 16384 : left, 0);
764
765                         EXPECT_GE(res, 0);
766                         left -= res;
767                 }
768
769                 pid2 = wait(&status);
770                 EXPECT_EQ(status, 0);
771                 EXPECT_EQ(res, pid2);
772         } else {
773                 /* child */
774                 size_t left = data;
775                 char buf[16384];
776
777                 while (left) {
778                         int res = recv(self->cfd, buf,
779                                        left > 16384 ? 16384 : left, 0);
780
781                         EXPECT_GE(res, 0);
782                         left -= res;
783                 }
784         }
785 }
786
787 TEST_F(tls, nonblocking)
788 {
789         size_t data = 100000;
790         int sendbuf = 100;
791         int flags;
792         int res;
793
794         flags = fcntl(self->fd, F_GETFL, 0);
795         fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
796         fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
797
798         /* Ensure nonblocking behavior by imposing a small send
799          * buffer.
800          */
801         EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
802                              &sendbuf, sizeof(sendbuf)), 0);
803
804         res = fork();
805         EXPECT_NE(res, -1);
806
807         if (res) {
808                 /* parent */
809                 bool eagain = false;
810                 size_t left = data;
811                 char buf[16384];
812                 int status;
813                 int pid2;
814
815                 while (left) {
816                         int res = send(self->fd, buf,
817                                        left > 16384 ? 16384 : left, 0);
818
819                         if (res == -1 && errno == EAGAIN) {
820                                 eagain = true;
821                                 usleep(10000);
822                                 continue;
823                         }
824                         EXPECT_GE(res, 0);
825                         left -= res;
826                 }
827
828                 EXPECT_TRUE(eagain);
829                 pid2 = wait(&status);
830
831                 EXPECT_EQ(status, 0);
832                 EXPECT_EQ(res, pid2);
833         } else {
834                 /* child */
835                 bool eagain = false;
836                 size_t left = data;
837                 char buf[16384];
838
839                 while (left) {
840                         int res = recv(self->cfd, buf,
841                                        left > 16384 ? 16384 : left, 0);
842
843                         if (res == -1 && errno == EAGAIN) {
844                                 eagain = true;
845                                 usleep(10000);
846                                 continue;
847                         }
848                         EXPECT_GE(res, 0);
849                         left -= res;
850                 }
851                 EXPECT_TRUE(eagain);
852         }
853 }
854
855 TEST_F(tls, control_msg)
856 {
857         if (self->notls)
858                 return;
859
860         char cbuf[CMSG_SPACE(sizeof(char))];
861         char const *test_str = "test_read";
862         int cmsg_len = sizeof(char);
863         char record_type = 100;
864         struct cmsghdr *cmsg;
865         struct msghdr msg;
866         int send_len = 10;
867         struct iovec vec;
868         char buf[10];
869
870         vec.iov_base = (char *)test_str;
871         vec.iov_len = 10;
872         memset(&msg, 0, sizeof(struct msghdr));
873         msg.msg_iov = &vec;
874         msg.msg_iovlen = 1;
875         msg.msg_control = cbuf;
876         msg.msg_controllen = sizeof(cbuf);
877         cmsg = CMSG_FIRSTHDR(&msg);
878         cmsg->cmsg_level = SOL_TLS;
879         /* test sending non-record types. */
880         cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
881         cmsg->cmsg_len = CMSG_LEN(cmsg_len);
882         *CMSG_DATA(cmsg) = record_type;
883         msg.msg_controllen = cmsg->cmsg_len;
884
885         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
886         /* Should fail because we didn't provide a control message */
887         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
888
889         vec.iov_base = buf;
890         EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL | MSG_PEEK), send_len);
891
892         cmsg = CMSG_FIRSTHDR(&msg);
893         EXPECT_NE(cmsg, NULL);
894         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
895         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
896         record_type = *((unsigned char *)CMSG_DATA(cmsg));
897         EXPECT_EQ(record_type, 100);
898         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
899
900         /* Recv the message again without MSG_PEEK */
901         record_type = 0;
902         memset(buf, 0, sizeof(buf));
903
904         EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL), send_len);
905         cmsg = CMSG_FIRSTHDR(&msg);
906         EXPECT_NE(cmsg, NULL);
907         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
908         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
909         record_type = *((unsigned char *)CMSG_DATA(cmsg));
910         EXPECT_EQ(record_type, 100);
911         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
912 }
913
914 TEST(keysizes) {
915         struct tls12_crypto_info_aes_gcm_256 tls12;
916         struct sockaddr_in addr;
917         int sfd, ret, fd, cfd;
918         socklen_t len;
919         bool notls;
920
921         notls = false;
922         len = sizeof(addr);
923
924         memset(&tls12, 0, sizeof(tls12));
925         tls12.info.version = TLS_1_2_VERSION;
926         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
927
928         addr.sin_family = AF_INET;
929         addr.sin_addr.s_addr = htonl(INADDR_ANY);
930         addr.sin_port = 0;
931
932         fd = socket(AF_INET, SOCK_STREAM, 0);
933         sfd = socket(AF_INET, SOCK_STREAM, 0);
934
935         ret = bind(sfd, &addr, sizeof(addr));
936         ASSERT_EQ(ret, 0);
937         ret = listen(sfd, 10);
938         ASSERT_EQ(ret, 0);
939
940         ret = getsockname(sfd, &addr, &len);
941         ASSERT_EQ(ret, 0);
942
943         ret = connect(fd, &addr, sizeof(addr));
944         ASSERT_EQ(ret, 0);
945
946         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
947         if (ret != 0) {
948                 notls = true;
949                 printf("Failure setting TCP_ULP, testing without tls\n");
950         }
951
952         if (!notls) {
953                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
954                                  sizeof(tls12));
955                 EXPECT_EQ(ret, 0);
956         }
957
958         cfd = accept(sfd, &addr, &len);
959         ASSERT_GE(cfd, 0);
960
961         if (!notls) {
962                 ret = setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls",
963                                  sizeof("tls"));
964                 EXPECT_EQ(ret, 0);
965
966                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
967                                  sizeof(tls12));
968                 EXPECT_EQ(ret, 0);
969         }
970
971         close(sfd);
972         close(fd);
973         close(cfd);
974 }
975
976 TEST(tls12) {
977         int fd, cfd;
978         bool notls;
979
980         struct tls12_crypto_info_aes_gcm_128 tls12;
981         struct sockaddr_in addr;
982         socklen_t len;
983         int sfd, ret;
984
985         notls = false;
986         len = sizeof(addr);
987
988         memset(&tls12, 0, sizeof(tls12));
989         tls12.info.version = TLS_1_2_VERSION;
990         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
991
992         addr.sin_family = AF_INET;
993         addr.sin_addr.s_addr = htonl(INADDR_ANY);
994         addr.sin_port = 0;
995
996         fd = socket(AF_INET, SOCK_STREAM, 0);
997         sfd = socket(AF_INET, SOCK_STREAM, 0);
998
999         ret = bind(sfd, &addr, sizeof(addr));
1000         ASSERT_EQ(ret, 0);
1001         ret = listen(sfd, 10);
1002         ASSERT_EQ(ret, 0);
1003
1004         ret = getsockname(sfd, &addr, &len);
1005         ASSERT_EQ(ret, 0);
1006
1007         ret = connect(fd, &addr, sizeof(addr));
1008         ASSERT_EQ(ret, 0);
1009
1010         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1011         if (ret != 0) {
1012                 notls = true;
1013                 printf("Failure setting TCP_ULP, testing without tls\n");
1014         }
1015
1016         if (!notls) {
1017                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
1018                                  sizeof(tls12));
1019                 ASSERT_EQ(ret, 0);
1020         }
1021
1022         cfd = accept(sfd, &addr, &len);
1023         ASSERT_GE(cfd, 0);
1024
1025         if (!notls) {
1026                 ret = setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls",
1027                                  sizeof("tls"));
1028                 ASSERT_EQ(ret, 0);
1029
1030                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
1031                                  sizeof(tls12));
1032                 ASSERT_EQ(ret, 0);
1033         }
1034
1035         close(sfd);
1036
1037         char const *test_str = "test_read";
1038         int send_len = 10;
1039         char buf[10];
1040
1041         send_len = strlen(test_str) + 1;
1042         EXPECT_EQ(send(fd, test_str, send_len, 0), send_len);
1043         EXPECT_NE(recv(cfd, buf, send_len, 0), -1);
1044         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1045
1046         close(fd);
1047         close(cfd);
1048 }
1049
1050 TEST_HARNESS_MAIN