Merge branch 'for-4.14-fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/tj...
[sfrench/cifs-2.6.git] / net / tls / tls_sw.c
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  *
8  * This software is available to you under a choice of one of two
9  * licenses.  You may choose to be licensed under the terms of the GNU
10  * General Public License (GPL) Version 2, available from the file
11  * COPYING in the main directory of this source tree, or the
12  * OpenIB.org BSD license below:
13  *
14  *     Redistribution and use in source and binary forms, with or
15  *     without modification, are permitted provided that the following
16  *     conditions are met:
17  *
18  *      - Redistributions of source code must retain the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer.
21  *
22  *      - Redistributions in binary form must reproduce the above
23  *        copyright notice, this list of conditions and the following
24  *        disclaimer in the documentation and/or other materials
25  *        provided with the distribution.
26  *
27  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34  * SOFTWARE.
35  */
36
37 #include <linux/module.h>
38 #include <crypto/aead.h>
39
40 #include <net/tls.h>
41
42 static inline void tls_make_aad(int recv,
43                                 char *buf,
44                                 size_t size,
45                                 char *record_sequence,
46                                 int record_sequence_size,
47                                 unsigned char record_type)
48 {
49         memcpy(buf, record_sequence, record_sequence_size);
50
51         buf[8] = record_type;
52         buf[9] = TLS_1_2_VERSION_MAJOR;
53         buf[10] = TLS_1_2_VERSION_MINOR;
54         buf[11] = size >> 8;
55         buf[12] = size & 0xFF;
56 }
57
58 static void trim_sg(struct sock *sk, struct scatterlist *sg,
59                     int *sg_num_elem, unsigned int *sg_size, int target_size)
60 {
61         int i = *sg_num_elem - 1;
62         int trim = *sg_size - target_size;
63
64         if (trim <= 0) {
65                 WARN_ON(trim < 0);
66                 return;
67         }
68
69         *sg_size = target_size;
70         while (trim >= sg[i].length) {
71                 trim -= sg[i].length;
72                 sk_mem_uncharge(sk, sg[i].length);
73                 put_page(sg_page(&sg[i]));
74                 i--;
75
76                 if (i < 0)
77                         goto out;
78         }
79
80         sg[i].length -= trim;
81         sk_mem_uncharge(sk, trim);
82
83 out:
84         *sg_num_elem = i + 1;
85 }
86
87 static void trim_both_sgl(struct sock *sk, int target_size)
88 {
89         struct tls_context *tls_ctx = tls_get_ctx(sk);
90         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
91
92         trim_sg(sk, ctx->sg_plaintext_data,
93                 &ctx->sg_plaintext_num_elem,
94                 &ctx->sg_plaintext_size,
95                 target_size);
96
97         if (target_size > 0)
98                 target_size += tls_ctx->overhead_size;
99
100         trim_sg(sk, ctx->sg_encrypted_data,
101                 &ctx->sg_encrypted_num_elem,
102                 &ctx->sg_encrypted_size,
103                 target_size);
104 }
105
106 static int alloc_sg(struct sock *sk, int len, struct scatterlist *sg,
107                     int *sg_num_elem, unsigned int *sg_size,
108                     int first_coalesce)
109 {
110         struct page_frag *pfrag;
111         unsigned int size = *sg_size;
112         int num_elem = *sg_num_elem, use = 0, rc = 0;
113         struct scatterlist *sge;
114         unsigned int orig_offset;
115
116         len -= size;
117         pfrag = sk_page_frag(sk);
118
119         while (len > 0) {
120                 if (!sk_page_frag_refill(sk, pfrag)) {
121                         rc = -ENOMEM;
122                         goto out;
123                 }
124
125                 use = min_t(int, len, pfrag->size - pfrag->offset);
126
127                 if (!sk_wmem_schedule(sk, use)) {
128                         rc = -ENOMEM;
129                         goto out;
130                 }
131
132                 sk_mem_charge(sk, use);
133                 size += use;
134                 orig_offset = pfrag->offset;
135                 pfrag->offset += use;
136
137                 sge = sg + num_elem - 1;
138                 if (num_elem > first_coalesce && sg_page(sg) == pfrag->page &&
139                     sg->offset + sg->length == orig_offset) {
140                         sg->length += use;
141                 } else {
142                         sge++;
143                         sg_unmark_end(sge);
144                         sg_set_page(sge, pfrag->page, use, orig_offset);
145                         get_page(pfrag->page);
146                         ++num_elem;
147                         if (num_elem == MAX_SKB_FRAGS) {
148                                 rc = -ENOSPC;
149                                 break;
150                         }
151                 }
152
153                 len -= use;
154         }
155         goto out;
156
157 out:
158         *sg_size = size;
159         *sg_num_elem = num_elem;
160         return rc;
161 }
162
163 static int alloc_encrypted_sg(struct sock *sk, int len)
164 {
165         struct tls_context *tls_ctx = tls_get_ctx(sk);
166         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
167         int rc = 0;
168
169         rc = alloc_sg(sk, len, ctx->sg_encrypted_data,
170                       &ctx->sg_encrypted_num_elem, &ctx->sg_encrypted_size, 0);
171
172         return rc;
173 }
174
175 static int alloc_plaintext_sg(struct sock *sk, int len)
176 {
177         struct tls_context *tls_ctx = tls_get_ctx(sk);
178         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
179         int rc = 0;
180
181         rc = alloc_sg(sk, len, ctx->sg_plaintext_data,
182                       &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
183                       tls_ctx->pending_open_record_frags);
184
185         return rc;
186 }
187
188 static void free_sg(struct sock *sk, struct scatterlist *sg,
189                     int *sg_num_elem, unsigned int *sg_size)
190 {
191         int i, n = *sg_num_elem;
192
193         for (i = 0; i < n; ++i) {
194                 sk_mem_uncharge(sk, sg[i].length);
195                 put_page(sg_page(&sg[i]));
196         }
197         *sg_num_elem = 0;
198         *sg_size = 0;
199 }
200
201 static void tls_free_both_sg(struct sock *sk)
202 {
203         struct tls_context *tls_ctx = tls_get_ctx(sk);
204         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
205
206         free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
207                 &ctx->sg_encrypted_size);
208
209         free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
210                 &ctx->sg_plaintext_size);
211 }
212
213 static int tls_do_encryption(struct tls_context *tls_ctx,
214                              struct tls_sw_context *ctx, size_t data_len,
215                              gfp_t flags)
216 {
217         unsigned int req_size = sizeof(struct aead_request) +
218                 crypto_aead_reqsize(ctx->aead_send);
219         struct aead_request *aead_req;
220         int rc;
221
222         aead_req = kmalloc(req_size, flags);
223         if (!aead_req)
224                 return -ENOMEM;
225
226         ctx->sg_encrypted_data[0].offset += tls_ctx->prepend_size;
227         ctx->sg_encrypted_data[0].length -= tls_ctx->prepend_size;
228
229         aead_request_set_tfm(aead_req, ctx->aead_send);
230         aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
231         aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
232                                data_len, tls_ctx->iv);
233         rc = crypto_aead_encrypt(aead_req);
234
235         ctx->sg_encrypted_data[0].offset -= tls_ctx->prepend_size;
236         ctx->sg_encrypted_data[0].length += tls_ctx->prepend_size;
237
238         kfree(aead_req);
239         return rc;
240 }
241
242 static int tls_push_record(struct sock *sk, int flags,
243                            unsigned char record_type)
244 {
245         struct tls_context *tls_ctx = tls_get_ctx(sk);
246         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
247         int rc;
248
249         sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
250         sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
251
252         tls_make_aad(0, ctx->aad_space, ctx->sg_plaintext_size,
253                      tls_ctx->rec_seq, tls_ctx->rec_seq_size,
254                      record_type);
255
256         tls_fill_prepend(tls_ctx,
257                          page_address(sg_page(&ctx->sg_encrypted_data[0])) +
258                          ctx->sg_encrypted_data[0].offset,
259                          ctx->sg_plaintext_size, record_type);
260
261         tls_ctx->pending_open_record_frags = 0;
262         set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
263
264         rc = tls_do_encryption(tls_ctx, ctx, ctx->sg_plaintext_size,
265                                sk->sk_allocation);
266         if (rc < 0) {
267                 /* If we are called from write_space and
268                  * we fail, we need to set this SOCK_NOSPACE
269                  * to trigger another write_space in the future.
270                  */
271                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
272                 return rc;
273         }
274
275         free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
276                 &ctx->sg_plaintext_size);
277
278         ctx->sg_encrypted_num_elem = 0;
279         ctx->sg_encrypted_size = 0;
280
281         /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
282         rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
283         if (rc < 0 && rc != -EAGAIN)
284                 tls_err_abort(sk);
285
286         tls_advance_record_sn(sk, tls_ctx);
287         return rc;
288 }
289
290 static int tls_sw_push_pending_record(struct sock *sk, int flags)
291 {
292         return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
293 }
294
295 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
296                               int length)
297 {
298         struct tls_context *tls_ctx = tls_get_ctx(sk);
299         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
300         struct page *pages[MAX_SKB_FRAGS];
301
302         size_t offset;
303         ssize_t copied, use;
304         int i = 0;
305         unsigned int size = ctx->sg_plaintext_size;
306         int num_elem = ctx->sg_plaintext_num_elem;
307         int rc = 0;
308         int maxpages;
309
310         while (length > 0) {
311                 i = 0;
312                 maxpages = ARRAY_SIZE(ctx->sg_plaintext_data) - num_elem;
313                 if (maxpages == 0) {
314                         rc = -EFAULT;
315                         goto out;
316                 }
317                 copied = iov_iter_get_pages(from, pages,
318                                             length,
319                                             maxpages, &offset);
320                 if (copied <= 0) {
321                         rc = -EFAULT;
322                         goto out;
323                 }
324
325                 iov_iter_advance(from, copied);
326
327                 length -= copied;
328                 size += copied;
329                 while (copied) {
330                         use = min_t(int, copied, PAGE_SIZE - offset);
331
332                         sg_set_page(&ctx->sg_plaintext_data[num_elem],
333                                     pages[i], use, offset);
334                         sg_unmark_end(&ctx->sg_plaintext_data[num_elem]);
335                         sk_mem_charge(sk, use);
336
337                         offset = 0;
338                         copied -= use;
339
340                         ++i;
341                         ++num_elem;
342                 }
343         }
344
345 out:
346         ctx->sg_plaintext_size = size;
347         ctx->sg_plaintext_num_elem = num_elem;
348         return rc;
349 }
350
351 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
352                              int bytes)
353 {
354         struct tls_context *tls_ctx = tls_get_ctx(sk);
355         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
356         struct scatterlist *sg = ctx->sg_plaintext_data;
357         int copy, i, rc = 0;
358
359         for (i = tls_ctx->pending_open_record_frags;
360              i < ctx->sg_plaintext_num_elem; ++i) {
361                 copy = sg[i].length;
362                 if (copy_from_iter(
363                                 page_address(sg_page(&sg[i])) + sg[i].offset,
364                                 copy, from) != copy) {
365                         rc = -EFAULT;
366                         goto out;
367                 }
368                 bytes -= copy;
369
370                 ++tls_ctx->pending_open_record_frags;
371
372                 if (!bytes)
373                         break;
374         }
375
376 out:
377         return rc;
378 }
379
380 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
381 {
382         struct tls_context *tls_ctx = tls_get_ctx(sk);
383         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
384         int ret = 0;
385         int required_size;
386         long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
387         bool eor = !(msg->msg_flags & MSG_MORE);
388         size_t try_to_copy, copied = 0;
389         unsigned char record_type = TLS_RECORD_TYPE_DATA;
390         int record_room;
391         bool full_record;
392         int orig_size;
393
394         if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
395                 return -ENOTSUPP;
396
397         lock_sock(sk);
398
399         if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
400                 goto send_end;
401
402         if (unlikely(msg->msg_controllen)) {
403                 ret = tls_proccess_cmsg(sk, msg, &record_type);
404                 if (ret)
405                         goto send_end;
406         }
407
408         while (msg_data_left(msg)) {
409                 if (sk->sk_err) {
410                         ret = sk->sk_err;
411                         goto send_end;
412                 }
413
414                 orig_size = ctx->sg_plaintext_size;
415                 full_record = false;
416                 try_to_copy = msg_data_left(msg);
417                 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
418                 if (try_to_copy >= record_room) {
419                         try_to_copy = record_room;
420                         full_record = true;
421                 }
422
423                 required_size = ctx->sg_plaintext_size + try_to_copy +
424                                 tls_ctx->overhead_size;
425
426                 if (!sk_stream_memory_free(sk))
427                         goto wait_for_sndbuf;
428 alloc_encrypted:
429                 ret = alloc_encrypted_sg(sk, required_size);
430                 if (ret) {
431                         if (ret != -ENOSPC)
432                                 goto wait_for_memory;
433
434                         /* Adjust try_to_copy according to the amount that was
435                          * actually allocated. The difference is due
436                          * to max sg elements limit
437                          */
438                         try_to_copy -= required_size - ctx->sg_encrypted_size;
439                         full_record = true;
440                 }
441
442                 if (full_record || eor) {
443                         ret = zerocopy_from_iter(sk, &msg->msg_iter,
444                                                  try_to_copy);
445                         if (ret)
446                                 goto fallback_to_reg_send;
447
448                         copied += try_to_copy;
449                         ret = tls_push_record(sk, msg->msg_flags, record_type);
450                         if (!ret)
451                                 continue;
452                         if (ret == -EAGAIN)
453                                 goto send_end;
454
455                         copied -= try_to_copy;
456 fallback_to_reg_send:
457                         iov_iter_revert(&msg->msg_iter,
458                                         ctx->sg_plaintext_size - orig_size);
459                         trim_sg(sk, ctx->sg_plaintext_data,
460                                 &ctx->sg_plaintext_num_elem,
461                                 &ctx->sg_plaintext_size,
462                                 orig_size);
463                 }
464
465                 required_size = ctx->sg_plaintext_size + try_to_copy;
466 alloc_plaintext:
467                 ret = alloc_plaintext_sg(sk, required_size);
468                 if (ret) {
469                         if (ret != -ENOSPC)
470                                 goto wait_for_memory;
471
472                         /* Adjust try_to_copy according to the amount that was
473                          * actually allocated. The difference is due
474                          * to max sg elements limit
475                          */
476                         try_to_copy -= required_size - ctx->sg_plaintext_size;
477                         full_record = true;
478
479                         trim_sg(sk, ctx->sg_encrypted_data,
480                                 &ctx->sg_encrypted_num_elem,
481                                 &ctx->sg_encrypted_size,
482                                 ctx->sg_plaintext_size +
483                                 tls_ctx->overhead_size);
484                 }
485
486                 ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
487                 if (ret)
488                         goto trim_sgl;
489
490                 copied += try_to_copy;
491                 if (full_record || eor) {
492 push_record:
493                         ret = tls_push_record(sk, msg->msg_flags, record_type);
494                         if (ret) {
495                                 if (ret == -ENOMEM)
496                                         goto wait_for_memory;
497
498                                 goto send_end;
499                         }
500                 }
501
502                 continue;
503
504 wait_for_sndbuf:
505                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
506 wait_for_memory:
507                 ret = sk_stream_wait_memory(sk, &timeo);
508                 if (ret) {
509 trim_sgl:
510                         trim_both_sgl(sk, orig_size);
511                         goto send_end;
512                 }
513
514                 if (tls_is_pending_closed_record(tls_ctx))
515                         goto push_record;
516
517                 if (ctx->sg_encrypted_size < required_size)
518                         goto alloc_encrypted;
519
520                 goto alloc_plaintext;
521         }
522
523 send_end:
524         ret = sk_stream_error(sk, msg->msg_flags, ret);
525
526         release_sock(sk);
527         return copied ? copied : ret;
528 }
529
530 int tls_sw_sendpage(struct sock *sk, struct page *page,
531                     int offset, size_t size, int flags)
532 {
533         struct tls_context *tls_ctx = tls_get_ctx(sk);
534         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
535         int ret = 0;
536         long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
537         bool eor;
538         size_t orig_size = size;
539         unsigned char record_type = TLS_RECORD_TYPE_DATA;
540         struct scatterlist *sg;
541         bool full_record;
542         int record_room;
543
544         if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
545                       MSG_SENDPAGE_NOTLAST))
546                 return -ENOTSUPP;
547
548         /* No MSG_EOR from splice, only look at MSG_MORE */
549         eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
550
551         lock_sock(sk);
552
553         sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
554
555         if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
556                 goto sendpage_end;
557
558         /* Call the sk_stream functions to manage the sndbuf mem. */
559         while (size > 0) {
560                 size_t copy, required_size;
561
562                 if (sk->sk_err) {
563                         ret = sk->sk_err;
564                         goto sendpage_end;
565                 }
566
567                 full_record = false;
568                 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
569                 copy = size;
570                 if (copy >= record_room) {
571                         copy = record_room;
572                         full_record = true;
573                 }
574                 required_size = ctx->sg_plaintext_size + copy +
575                               tls_ctx->overhead_size;
576
577                 if (!sk_stream_memory_free(sk))
578                         goto wait_for_sndbuf;
579 alloc_payload:
580                 ret = alloc_encrypted_sg(sk, required_size);
581                 if (ret) {
582                         if (ret != -ENOSPC)
583                                 goto wait_for_memory;
584
585                         /* Adjust copy according to the amount that was
586                          * actually allocated. The difference is due
587                          * to max sg elements limit
588                          */
589                         copy -= required_size - ctx->sg_plaintext_size;
590                         full_record = true;
591                 }
592
593                 get_page(page);
594                 sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
595                 sg_set_page(sg, page, copy, offset);
596                 ctx->sg_plaintext_num_elem++;
597
598                 sk_mem_charge(sk, copy);
599                 offset += copy;
600                 size -= copy;
601                 ctx->sg_plaintext_size += copy;
602                 tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
603
604                 if (full_record || eor ||
605                     ctx->sg_plaintext_num_elem ==
606                     ARRAY_SIZE(ctx->sg_plaintext_data)) {
607 push_record:
608                         ret = tls_push_record(sk, flags, record_type);
609                         if (ret) {
610                                 if (ret == -ENOMEM)
611                                         goto wait_for_memory;
612
613                                 goto sendpage_end;
614                         }
615                 }
616                 continue;
617 wait_for_sndbuf:
618                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
619 wait_for_memory:
620                 ret = sk_stream_wait_memory(sk, &timeo);
621                 if (ret) {
622                         trim_both_sgl(sk, ctx->sg_plaintext_size);
623                         goto sendpage_end;
624                 }
625
626                 if (tls_is_pending_closed_record(tls_ctx))
627                         goto push_record;
628
629                 goto alloc_payload;
630         }
631
632 sendpage_end:
633         if (orig_size > size)
634                 ret = orig_size - size;
635         else
636                 ret = sk_stream_error(sk, flags, ret);
637
638         release_sock(sk);
639         return ret;
640 }
641
642 static void tls_sw_free_resources(struct sock *sk)
643 {
644         struct tls_context *tls_ctx = tls_get_ctx(sk);
645         struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
646
647         if (ctx->aead_send)
648                 crypto_free_aead(ctx->aead_send);
649
650         tls_free_both_sg(sk);
651
652         kfree(ctx);
653 }
654
655 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
656 {
657         char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
658         struct tls_crypto_info *crypto_info;
659         struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
660         struct tls_sw_context *sw_ctx;
661         u16 nonce_size, tag_size, iv_size, rec_seq_size;
662         char *iv, *rec_seq;
663         int rc = 0;
664
665         if (!ctx) {
666                 rc = -EINVAL;
667                 goto out;
668         }
669
670         if (ctx->priv_ctx) {
671                 rc = -EEXIST;
672                 goto out;
673         }
674
675         sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
676         if (!sw_ctx) {
677                 rc = -ENOMEM;
678                 goto out;
679         }
680
681         ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
682         ctx->free_resources = tls_sw_free_resources;
683
684         crypto_info = &ctx->crypto_send;
685         switch (crypto_info->cipher_type) {
686         case TLS_CIPHER_AES_GCM_128: {
687                 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
688                 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
689                 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
690                 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
691                 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
692                 rec_seq =
693                  ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
694                 gcm_128_info =
695                         (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
696                 break;
697         }
698         default:
699                 rc = -EINVAL;
700                 goto out;
701         }
702
703         ctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
704         ctx->tag_size = tag_size;
705         ctx->overhead_size = ctx->prepend_size + ctx->tag_size;
706         ctx->iv_size = iv_size;
707         ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
708                           GFP_KERNEL);
709         if (!ctx->iv) {
710                 rc = -ENOMEM;
711                 goto out;
712         }
713         memcpy(ctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
714         memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
715         ctx->rec_seq_size = rec_seq_size;
716         ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
717         if (!ctx->rec_seq) {
718                 rc = -ENOMEM;
719                 goto free_iv;
720         }
721         memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
722
723         sg_init_table(sw_ctx->sg_encrypted_data,
724                       ARRAY_SIZE(sw_ctx->sg_encrypted_data));
725         sg_init_table(sw_ctx->sg_plaintext_data,
726                       ARRAY_SIZE(sw_ctx->sg_plaintext_data));
727
728         sg_init_table(sw_ctx->sg_aead_in, 2);
729         sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
730                    sizeof(sw_ctx->aad_space));
731         sg_unmark_end(&sw_ctx->sg_aead_in[1]);
732         sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
733         sg_init_table(sw_ctx->sg_aead_out, 2);
734         sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
735                    sizeof(sw_ctx->aad_space));
736         sg_unmark_end(&sw_ctx->sg_aead_out[1]);
737         sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
738
739         if (!sw_ctx->aead_send) {
740                 sw_ctx->aead_send = crypto_alloc_aead("gcm(aes)", 0, 0);
741                 if (IS_ERR(sw_ctx->aead_send)) {
742                         rc = PTR_ERR(sw_ctx->aead_send);
743                         sw_ctx->aead_send = NULL;
744                         goto free_rec_seq;
745                 }
746         }
747
748         ctx->push_pending_record = tls_sw_push_pending_record;
749
750         memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
751
752         rc = crypto_aead_setkey(sw_ctx->aead_send, keyval,
753                                 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
754         if (rc)
755                 goto free_aead;
756
757         rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tag_size);
758         if (!rc)
759                 goto out;
760
761 free_aead:
762         crypto_free_aead(sw_ctx->aead_send);
763         sw_ctx->aead_send = NULL;
764 free_rec_seq:
765         kfree(ctx->rec_seq);
766         ctx->rec_seq = NULL;
767 free_iv:
768         kfree(ctx->iv);
769         ctx->iv = NULL;
770 out:
771         return rc;
772 }