/* net/tls/tls_device.c */
/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <crypto/aead.h>
#include <linux/highmem.h>
#include <linux/module.h>
#include <linux/netdevice.h>
#include <net/dst.h>
#include <net/inet_connection_sock.h>
#include <net/tcp.h>
#include <net/tls.h>

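/* Drivers plug into this file through their struct tlsdev_ops callbacks
 * (tls_dev_add, tls_dev_del and, for RX offload, tls_dev_resync_rx). As a
 * minimal sketch only ("foo" names are hypothetical, not a real driver),
 * a NIC advertising TLS offload would register something like:
 *
 *	static const struct tlsdev_ops foo_tlsdev_ops = {
 *		.tls_dev_add       = foo_tls_dev_add,
 *		.tls_dev_del       = foo_tls_dev_del,
 *		.tls_dev_resync_rx = foo_tls_dev_resync_rx,
 *	};
 *
 *	netdev->tlsdev_ops = &foo_tlsdev_ops;
 *	netdev->features  |= NETIF_F_HW_TLS_TX | NETIF_F_HW_TLS_RX;
 */
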
/* device_offload_lock is used to synchronize tls_dev_add
 * against NETDEV_DOWN notifications.
 */
static DECLARE_RWSEM(device_offload_lock);

static void tls_device_gc_task(struct work_struct *work);

static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
static LIST_HEAD(tls_device_gc_list);
static LIST_HEAD(tls_device_list);
static DEFINE_SPINLOCK(tls_device_lock);

static void tls_device_free_ctx(struct tls_context *ctx)
{
        if (ctx->tx_conf == TLS_HW)
                kfree(tls_offload_ctx_tx(ctx));

        if (ctx->rx_conf == TLS_HW)
                kfree(tls_offload_ctx_rx(ctx));

        kfree(ctx);
}

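/* Tear down contexts queued on tls_device_gc_list outside of atomic
 * context: remove the TX state from the device and drop the netdev
 * reference taken in tls_device_attach().
 */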
static void tls_device_gc_task(struct work_struct *work)
{
        struct tls_context *ctx, *tmp;
        unsigned long flags;
        LIST_HEAD(gc_list);

        spin_lock_irqsave(&tls_device_lock, flags);
        list_splice_init(&tls_device_gc_list, &gc_list);
        spin_unlock_irqrestore(&tls_device_lock, flags);

        list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
                struct net_device *netdev = ctx->netdev;

                if (netdev && ctx->tx_conf == TLS_HW) {
                        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                        TLS_OFFLOAD_CTX_DIR_TX);
                        dev_put(netdev);
                        ctx->netdev = NULL;
                }

                list_del(&ctx->list);
                tls_device_free_ctx(ctx);
        }
}

static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
                              struct net_device *netdev)
{
        if (sk->sk_destruct != tls_device_sk_destruct) {
                refcount_set(&ctx->refcount, 1);
                dev_hold(netdev);
                ctx->netdev = netdev;
                spin_lock_irq(&tls_device_lock);
                list_add_tail(&ctx->list, &tls_device_list);
                spin_unlock_irq(&tls_device_lock);

                ctx->sk_destruct = sk->sk_destruct;
                sk->sk_destruct = tls_device_sk_destruct;
        }
}

static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
{
        unsigned long flags;

        spin_lock_irqsave(&tls_device_lock, flags);
        list_move_tail(&ctx->list, &tls_device_gc_list);

        /* schedule_work inside the spinlock
         * to make sure tls_device_down waits for that work.
         */
        schedule_work(&tls_device_gc_work);

        spin_unlock_irqrestore(&tls_device_lock, flags);
}

/* We assume that the socket is already connected */
static struct net_device *get_netdev_for_sock(struct sock *sk)
{
        struct dst_entry *dst = sk_dst_get(sk);
        struct net_device *netdev = NULL;

        if (likely(dst)) {
                netdev = dst->dev;
                dev_hold(netdev);
        }

        dst_release(dst);

        return netdev;
}

static void destroy_record(struct tls_record_info *record)
{
        int nr_frags = record->num_frags;
        skb_frag_t *frag;

        while (nr_frags-- > 0) {
                frag = &record->frags[nr_frags];
                __skb_frag_unref(frag);
        }
        kfree(record);
}

static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
{
        struct tls_record_info *info, *temp;

        list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
                list_del(&info->list);
                destroy_record(info);
        }

        offload_ctx->retransmit_hint = NULL;
}

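/* Hooked into TCP ACK processing via clean_acked_data_enable(); frees
 * every closed TX record whose end_seq has been fully acknowledged.
 */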
static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_record_info *info, *temp;
        struct tls_offload_context_tx *ctx;
        u64 deleted_records = 0;
        unsigned long flags;

        if (!tls_ctx)
                return;

        ctx = tls_offload_ctx_tx(tls_ctx);

        spin_lock_irqsave(&ctx->lock, flags);
        info = ctx->retransmit_hint;
        if (info && !before(acked_seq, info->end_seq)) {
                ctx->retransmit_hint = NULL;
                list_del(&info->list);
                destroy_record(info);
                deleted_records++;
        }

        list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
                if (before(acked_seq, info->end_seq))
                        break;
                list_del(&info->list);

                destroy_record(info);
                deleted_records++;
        }

        ctx->unacked_record_sn += deleted_records;
        spin_unlock_irqrestore(&ctx->lock, flags);
}

/* At this point, there should be no references on this
 * socket and no in-flight SKBs associated with this
 * socket, so it is safe to free all the resources.
 */
void tls_device_sk_destruct(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);

        tls_ctx->sk_destruct(sk);

        if (tls_ctx->tx_conf == TLS_HW) {
                if (ctx->open_record)
                        destroy_record(ctx->open_record);
                delete_all_records(ctx);
                crypto_free_aead(ctx->aead_send);
                clean_acked_data_disable(inet_csk(sk));
        }

        if (refcount_dec_and_test(&tls_ctx->refcount))
                tls_device_queue_ctx_destruction(tls_ctx);
}
EXPORT_SYMBOL(tls_device_sk_destruct);

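/* Append @size bytes at @pfrag's current offset to @record: extend the
 * last fragment when the new bytes are contiguous with it, otherwise
 * start a new fragment and take an extra page reference.
 */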
static void tls_append_frag(struct tls_record_info *record,
                            struct page_frag *pfrag,
                            int size)
{
        skb_frag_t *frag;

        frag = &record->frags[record->num_frags - 1];
        if (frag->page.p == pfrag->page &&
            frag->page_offset + frag->size == pfrag->offset) {
                frag->size += size;
        } else {
                ++frag;
                frag->page.p = pfrag->page;
                frag->page_offset = pfrag->offset;
                frag->size = size;
                ++record->num_frags;
                get_page(pfrag->page);
        }

        pfrag->offset += size;
        record->len += size;
}

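/* Close the open record: write the TLS prepend (header and explicit IV)
 * into the first fragment, append a dummy fragment for the device-filled
 * authentication tag, queue the record on records_list, and hand its
 * pages to the TCP layer via tls_push_sg().
 */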
static int tls_push_record(struct sock *sk,
                           struct tls_context *ctx,
                           struct tls_offload_context_tx *offload_ctx,
                           struct tls_record_info *record,
                           struct page_frag *pfrag,
                           int flags,
                           unsigned char record_type)
{
        struct tcp_sock *tp = tcp_sk(sk);
        struct page_frag dummy_tag_frag;
        skb_frag_t *frag;
        int i;

        /* fill prepend */
        frag = &record->frags[0];
        tls_fill_prepend(ctx,
                         skb_frag_address(frag),
                         record->len - ctx->tx.prepend_size,
                         record_type);

        /* HW doesn't care about the data in the tag, because it fills it. */
        dummy_tag_frag.page = skb_frag_page(frag);
        dummy_tag_frag.offset = 0;

        tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size);
        record->end_seq = tp->write_seq + record->len;
        spin_lock_irq(&offload_ctx->lock);
        list_add_tail(&record->list, &offload_ctx->records_list);
        spin_unlock_irq(&offload_ctx->lock);
        offload_ctx->open_record = NULL;
        set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
        tls_advance_record_sn(sk, &ctx->tx);

        for (i = 0; i < record->num_frags; i++) {
                frag = &record->frags[i];
                sg_unmark_end(&offload_ctx->sg_tx_data[i]);
                sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
                            frag->size, frag->page_offset);
                sk_mem_charge(sk, frag->size);
                get_page(skb_frag_page(frag));
        }
        sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);

        /* all ready, send */
        return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
}

static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
                                 struct page_frag *pfrag,
                                 size_t prepend_size)
{
        struct tls_record_info *record;
        skb_frag_t *frag;

        record = kmalloc(sizeof(*record), GFP_KERNEL);
        if (!record)
                return -ENOMEM;

        frag = &record->frags[0];
        __skb_frag_set_page(frag, pfrag->page);
        frag->page_offset = pfrag->offset;
        skb_frag_size_set(frag, prepend_size);

        get_page(pfrag->page);
        pfrag->offset += prepend_size;

        record->num_frags = 1;
        record->len = prepend_size;
        offload_ctx->open_record = record;
        return 0;
}

static int tls_do_allocation(struct sock *sk,
                             struct tls_offload_context_tx *offload_ctx,
                             struct page_frag *pfrag,
                             size_t prepend_size)
{
        int ret;

        if (!offload_ctx->open_record) {
                if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
                                                   sk->sk_allocation))) {
                        sk->sk_prot->enter_memory_pressure(sk);
                        sk_stream_moderate_sndbuf(sk);
                        return -ENOMEM;
                }

                ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
                if (ret)
                        return ret;

                if (pfrag->size > pfrag->offset)
                        return 0;
        }

        if (!sk_page_frag_refill(sk, pfrag))
                return -ENOMEM;

        return 0;
}

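/* Core of the TX path: copy data from @msg_iter into page fragments of
 * the open record (allocating a record as needed) and push a record out
 * whenever it is full, the fragment limit is near, or the caller's data
 * is exhausted. Returns the number of bytes consumed, if any.
 */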
static int tls_push_data(struct sock *sk,
                         struct iov_iter *msg_iter,
                         size_t size, int flags,
                         unsigned char record_type)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
        int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
        int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
        struct tls_record_info *record = ctx->open_record;
        struct page_frag *pfrag;
        size_t orig_size = size;
        u32 max_open_record_len;
        int copy, rc = 0;
        bool done = false;
        long timeo;

        if (flags &
            ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
                return -ENOTSUPP;

        if (sk->sk_err)
                return -sk->sk_err;

        timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
        rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
        if (rc < 0)
                return rc;

        pfrag = sk_page_frag(sk);

        /* TLS_HEADER_SIZE is not counted as part of the TLS record, and
         * we need to leave room for an authentication tag.
         */
        max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
                              tls_ctx->tx.prepend_size;
        do {
                rc = tls_do_allocation(sk, ctx, pfrag,
                                       tls_ctx->tx.prepend_size);
                if (rc) {
                        rc = sk_stream_wait_memory(sk, &timeo);
                        if (!rc)
                                continue;

                        record = ctx->open_record;
                        if (!record)
                                break;
handle_error:
                        if (record_type != TLS_RECORD_TYPE_DATA) {
                                /* avoid sending partial
                                 * record with type !=
                                 * application_data
                                 */
                                size = orig_size;
                                destroy_record(record);
                                ctx->open_record = NULL;
                        } else if (record->len > tls_ctx->tx.prepend_size) {
                                goto last_record;
                        }

                        break;
                }

                record = ctx->open_record;
                copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
                copy = min_t(size_t, copy, (max_open_record_len - record->len));

                if (copy_from_iter_nocache(page_address(pfrag->page) +
                                               pfrag->offset,
                                           copy, msg_iter) != copy) {
                        rc = -EFAULT;
                        goto handle_error;
                }
                tls_append_frag(record, pfrag, copy);

                size -= copy;
                if (!size) {
last_record:
                        tls_push_record_flags = flags;
                        if (more) {
                                tls_ctx->pending_open_record_frags =
                                                !!record->num_frags;
                                break;
                        }

                        done = true;
                }

                if (done || record->len >= max_open_record_len ||
                    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
                        rc = tls_push_record(sk,
                                             tls_ctx,
                                             ctx,
                                             record,
                                             pfrag,
                                             tls_push_record_flags,
                                             record_type);
                        if (rc < 0)
                                break;
                }
        } while (!done);

        if (orig_size - size > 0)
                rc = orig_size - size;

        return rc;
}

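/* sendmsg() for TLS_HW TX sockets. A TLS_SET_RECORD_TYPE control message
 * may override the record type, e.g. for alert or handshake records.
 */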
int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        unsigned char record_type = TLS_RECORD_TYPE_DATA;
        int rc;

        lock_sock(sk);

        if (unlikely(msg->msg_controllen)) {
                rc = tls_proccess_cmsg(sk, msg, &record_type);
                if (rc)
                        goto out;
        }

        rc = tls_push_data(sk, &msg->msg_iter, size,
                           msg->msg_flags, record_type);

out:
        release_sock(sk);
        return rc;
}

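/* sendpage() for TLS_HW TX sockets: the page is kmapped and funnelled
 * through the same tls_push_data() path as sendmsg(); MSG_OOB is not
 * supported.
 */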
int tls_device_sendpage(struct sock *sk, struct page *page,
                        int offset, size_t size, int flags)
{
        struct iov_iter msg_iter;
        char *kaddr = kmap(page);
        struct kvec iov;
        int rc;

        if (flags & MSG_SENDPAGE_NOTLAST)
                flags |= MSG_MORE;

        lock_sock(sk);

        if (flags & MSG_OOB) {
                rc = -ENOTSUPP;
                goto out;
        }

        iov.iov_base = kaddr + offset;
        iov.iov_len = size;
        iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
        rc = tls_push_data(sk, &msg_iter, size,
                           flags, TLS_RECORD_TYPE_DATA);
        kunmap(page);

out:
        release_sock(sk);
        return rc;
}

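/* Find the TX record containing TCP sequence number @seq and return the
 * record's TLS sequence number through @p_record_sn; the result is
 * cached in retransmit_hint to speed up mostly-in-order lookups. Callers
 * are expected to hold the context's lock, since the list is updated
 * under it elsewhere in this file. A minimal sketch of a driver's
 * retransmission path (foo_* names are hypothetical):
 *
 *	spin_lock_irqsave(&offload_ctx->lock, flags);
 *	record = tls_get_record(offload_ctx, tcp_seq, &record_sn);
 *	if (record)
 *		foo_rebuild_tls_state(priv, record, record_sn);
 *	spin_unlock_irqrestore(&offload_ctx->lock, flags);
 */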
struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
                                       u32 seq, u64 *p_record_sn)
{
        u64 record_sn = context->hint_record_sn;
        struct tls_record_info *info;

        info = context->retransmit_hint;
        if (!info ||
            before(seq, info->end_seq - info->len)) {
                /* if retransmit_hint is irrelevant, start
                 * from the beginning of the list
                 */
                info = list_first_entry(&context->records_list,
                                        struct tls_record_info, list);
                record_sn = context->unacked_record_sn;
        }

        list_for_each_entry_from(info, &context->records_list, list) {
                if (before(seq, info->end_seq)) {
                        if (!context->retransmit_hint ||
                            after(info->end_seq,
                                  context->retransmit_hint->end_seq)) {
                                context->hint_record_sn = record_sn;
                                context->retransmit_hint = info;
                        }
                        *p_record_sn = record_sn;
                        return info;
                }
                record_sn++;
        }

        return NULL;
}
EXPORT_SYMBOL(tls_get_record);

static int tls_device_push_pending_record(struct sock *sk, int flags)
{
        struct iov_iter msg_iter;

        iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
        return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
}

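/* A device that lost decryption sync posts a resync request through
 * rx_ctx->resync_req: the upper 32 bits carry the TCP sequence of the
 * TLS header it stopped at (in network byte order, as the ntohl() below
 * implies) and any nonzero value marks the request as pending. Once the
 * stream position matches, the record sequence number is fed back to the
 * device via tls_dev_resync_rx().
 */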
void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct net_device *netdev = tls_ctx->netdev;
        struct tls_offload_context_rx *rx_ctx;
        u32 is_req_pending;
        s64 resync_req;
        u32 req_seq;

        if (tls_ctx->rx_conf != TLS_HW)
                return;

        rx_ctx = tls_offload_ctx_rx(tls_ctx);
        resync_req = atomic64_read(&rx_ctx->resync_req);
        req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
        is_req_pending = resync_req;

        if (unlikely(is_req_pending) && req_seq == seq &&
            atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
                netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk,
                                                      seq + TLS_HEADER_SIZE - 1,
                                                      rcd_sn);
}

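/* Software fallback for a record the device decrypted only partially:
 * run the software decryption over the whole record into a scratch
 * buffer (the -EBADMSG tag failure is expected, since the input is a mix
 * of plaintext and ciphertext) and copy the output back over the regions
 * the device already decrypted. Because AES-GCM applies an XOR
 * keystream, "decrypting" those plaintext regions regenerates their
 * ciphertext, leaving a uniformly encrypted record for the regular
 * software RX path.
 */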
static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
{
        struct strp_msg *rxm = strp_msg(skb);
        int err = 0, offset = rxm->offset, copy, nsg;
        struct sk_buff *skb_iter, *unused;
        struct scatterlist sg[1];
        char *orig_buf, *buf;

        orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
                           TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
        if (!orig_buf)
                return -ENOMEM;
        buf = orig_buf;

        nsg = skb_cow_data(skb, 0, &unused);
        if (unlikely(nsg < 0)) {
                err = nsg;
                goto free_buf;
        }

        sg_init_table(sg, 1);
        sg_set_buf(&sg[0], buf,
                   rxm->full_len + TLS_HEADER_SIZE +
                   TLS_CIPHER_AES_GCM_128_IV_SIZE);
        skb_copy_bits(skb, offset, buf,
                      TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);

        /* We are interested only in the decrypted data, not the auth tag. */
        err = decrypt_skb(sk, skb, sg);
        if (err != -EBADMSG)
                goto free_buf;
        else
                err = 0;

        copy = min_t(int, skb_pagelen(skb) - offset,
                     rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);

        if (skb->decrypted)
                skb_store_bits(skb, offset, buf, copy);

        offset += copy;
        buf += copy;

        skb_walk_frags(skb, skb_iter) {
                copy = min_t(int, skb_iter->len,
                             rxm->full_len - offset + rxm->offset -
                             TLS_CIPHER_AES_GCM_128_TAG_SIZE);

                if (skb_iter->decrypted)
                        skb_store_bits(skb_iter, offset, buf, copy);

                offset += copy;
                buf += copy;
        }

free_buf:
        kfree(orig_buf);
        return err;
}

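/* Called for each received record: decide whether it is entirely
 * plaintext (the device decrypted it), entirely ciphertext (leave it to
 * the software path), or mixed, in which case tls_device_reencrypt()
 * makes it uniform again.
 */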
int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
        int is_decrypted = skb->decrypted;
        int is_encrypted = !is_decrypted;
        struct sk_buff *skb_iter;

        /* Skip if it is already decrypted */
        if (ctx->sw.decrypted)
                return 0;

        /* Check if all the data is decrypted already */
        skb_walk_frags(skb, skb_iter) {
                is_decrypted &= skb_iter->decrypted;
                is_encrypted &= !skb_iter->decrypted;
        }

        ctx->sw.decrypted |= is_decrypted;

        /* Return immediately if the record is either entirely plaintext or
         * entirely ciphertext. Otherwise handle reencryption of the
         * partially decrypted record.
         */
        return (is_encrypted || is_decrypted) ? 0 :
                tls_device_reencrypt(sk, skb);
}

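/* Enable TLS_HW TX offload on @sk: validate the cipher (only AES-GCM-128
 * is supported here), set up the offload context and the software
 * fallback, insert a zero-length start marker record at the current
 * write_seq, and install the key material in the device through
 * tls_dev_add().
 */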
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
        u16 nonce_size, tag_size, iv_size, rec_seq_size;
        struct tls_record_info *start_marker_record;
        struct tls_offload_context_tx *offload_ctx;
        struct tls_crypto_info *crypto_info;
        struct net_device *netdev;
        char *iv, *rec_seq;
        struct sk_buff *skb;
        int rc = -EINVAL;
        __be64 rcd_sn;

        if (!ctx)
                goto out;

        if (ctx->priv_ctx_tx) {
                rc = -EEXIST;
                goto out;
        }

        start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
        if (!start_marker_record) {
                rc = -ENOMEM;
                goto out;
        }

        offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
        if (!offload_ctx) {
                rc = -ENOMEM;
                goto free_marker_record;
        }

        crypto_info = &ctx->crypto_send.info;
        switch (crypto_info->cipher_type) {
        case TLS_CIPHER_AES_GCM_128:
                nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
                tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
                iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
                iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
                rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
                rec_seq =
                 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
                break;
        default:
                rc = -EINVAL;
                goto free_offload_ctx;
        }

        ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
        ctx->tx.tag_size = tag_size;
        ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
        ctx->tx.iv_size = iv_size;
        ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
                             GFP_KERNEL);
        if (!ctx->tx.iv) {
                rc = -ENOMEM;
                goto free_offload_ctx;
        }

        memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);

        ctx->tx.rec_seq_size = rec_seq_size;
        ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
        if (!ctx->tx.rec_seq) {
                rc = -ENOMEM;
                goto free_iv;
        }

        rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
        if (rc)
                goto free_rec_seq;

        /* start at rec_seq - 1 to account for the start marker record */
        memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
        offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;

        start_marker_record->end_seq = tcp_sk(sk)->write_seq;
        start_marker_record->len = 0;
        start_marker_record->num_frags = 0;

        INIT_LIST_HEAD(&offload_ctx->records_list);
        list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
        spin_lock_init(&offload_ctx->lock);
        sg_init_table(offload_ctx->sg_tx_data,
                      ARRAY_SIZE(offload_ctx->sg_tx_data));

        clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
        ctx->push_pending_record = tls_device_push_pending_record;

        /* TLS offload is greatly simplified if we don't send
         * SKBs where only part of the payload needs to be encrypted.
         * So mark the last skb in the write queue as end of record.
         */
        skb = tcp_write_queue_tail(sk);
        if (skb)
                TCP_SKB_CB(skb)->eor = 1;

        /* We support starting offload on multiple sockets
         * concurrently, so we only need a read lock here.
         * This lock must precede get_netdev_for_sock to prevent races between
         * NETDEV_DOWN and setsockopt.
         */
        down_read(&device_offload_lock);
        netdev = get_netdev_for_sock(sk);
        if (!netdev) {
                pr_err_ratelimited("%s: netdev not found\n", __func__);
                rc = -EINVAL;
                goto release_lock;
        }

        if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
                rc = -ENOTSUPP;
                goto release_netdev;
        }

        /* Avoid offloading if the device is down
         * We don't want to offload new flows after
         * the NETDEV_DOWN event
         */
        if (!(netdev->flags & IFF_UP)) {
                rc = -EINVAL;
                goto release_netdev;
        }

        ctx->priv_ctx_tx = offload_ctx;
        rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
                                             &ctx->crypto_send.info,
                                             tcp_sk(sk)->write_seq);
        if (rc)
                goto release_netdev;

        tls_device_attach(ctx, sk, netdev);

        /* following this assignment tls_is_sk_tx_device_offloaded
         * will return true and the context might be accessed
         * by the netdev's xmit function.
         */
        smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
        dev_put(netdev);
        up_read(&device_offload_lock);
        goto out;

release_netdev:
        dev_put(netdev);
release_lock:
        up_read(&device_offload_lock);
        clean_acked_data_disable(inet_csk(sk));
        crypto_free_aead(offload_ctx->aead_send);
free_rec_seq:
        kfree(ctx->tx.rec_seq);
free_iv:
        kfree(ctx->tx.iv);
free_offload_ctx:
        kfree(offload_ctx);
        ctx->priv_ctx_tx = NULL;
free_marker_record:
        kfree(start_marker_record);
out:
        return rc;
}

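/* Enable TLS_HW RX offload on @sk. The software RX context is set up as
 * well, since records the device fails to decrypt fall back to it.
 */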
int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
{
        struct tls_offload_context_rx *context;
        struct net_device *netdev;
        int rc = 0;

        /* We support starting offload on multiple sockets
         * concurrently, so we only need a read lock here.
         * This lock must precede get_netdev_for_sock to prevent races between
         * NETDEV_DOWN and setsockopt.
         */
        down_read(&device_offload_lock);
        netdev = get_netdev_for_sock(sk);
        if (!netdev) {
                pr_err_ratelimited("%s: netdev not found\n", __func__);
                rc = -EINVAL;
                goto release_lock;
        }

        if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
                pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
                                   __func__, netdev->name);
                rc = -ENOTSUPP;
                goto release_netdev;
        }

        /* Avoid offloading if the device is down
         * We don't want to offload new flows after
         * the NETDEV_DOWN event
         */
        if (!(netdev->flags & IFF_UP)) {
                rc = -EINVAL;
                goto release_netdev;
        }

        context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
        if (!context) {
                rc = -ENOMEM;
                goto release_netdev;
        }

        ctx->priv_ctx_rx = context;
        rc = tls_set_sw_offload(sk, ctx, 0);
        if (rc)
                goto release_ctx;

        rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
                                             &ctx->crypto_recv.info,
                                             tcp_sk(sk)->copied_seq);
        if (rc) {
                pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
                                   __func__);
                goto free_sw_resources;
        }

        tls_device_attach(ctx, sk, netdev);
        goto release_netdev;

free_sw_resources:
        tls_sw_free_resources_rx(sk);
release_ctx:
        ctx->priv_ctx_rx = NULL;
release_netdev:
        dev_put(netdev);
release_lock:
        up_read(&device_offload_lock);
        return rc;
}

void tls_device_offload_cleanup_rx(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct net_device *netdev;

        down_read(&device_offload_lock);
        netdev = tls_ctx->netdev;
        if (!netdev)
                goto out;

        if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
                pr_err_ratelimited("%s: device is missing NETIF_F_HW_TLS_RX cap\n",
                                   __func__);
                goto out;
        }

        netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
                                        TLS_OFFLOAD_CTX_DIR_RX);

        if (tls_ctx->tx_conf != TLS_HW) {
                dev_put(netdev);
                tls_ctx->netdev = NULL;
        }
out:
        up_read(&device_offload_lock);
        kfree(tls_ctx->rx.rec_seq);
        kfree(tls_ctx->rx.iv);
        tls_sw_release_resources_rx(sk);
}

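/* NETDEV_DOWN handler. With the write side of device_offload_lock held
 * (so no new offloads can start), strip the device state from every
 * context attached to @netdev and drop the netdev references.
 */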
static int tls_device_down(struct net_device *netdev)
{
        struct tls_context *ctx, *tmp;
        unsigned long flags;
        LIST_HEAD(list);

        /* Request a write lock to block new offload attempts */
        down_write(&device_offload_lock);

        spin_lock_irqsave(&tls_device_lock, flags);
        list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
                if (ctx->netdev != netdev ||
                    !refcount_inc_not_zero(&ctx->refcount))
                        continue;

                list_move(&ctx->list, &list);
        }
        spin_unlock_irqrestore(&tls_device_lock, flags);

        list_for_each_entry_safe(ctx, tmp, &list, list) {
                if (ctx->tx_conf == TLS_HW)
                        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                        TLS_OFFLOAD_CTX_DIR_TX);
                if (ctx->rx_conf == TLS_HW)
                        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                        TLS_OFFLOAD_CTX_DIR_RX);
                ctx->netdev = NULL;
                dev_put(netdev);
                list_del_init(&ctx->list);

                if (refcount_dec_and_test(&ctx->refcount))
                        tls_device_free_ctx(ctx);
        }

        up_write(&device_offload_lock);

        flush_work(&tls_device_gc_work);

        return NOTIFY_DONE;
}

static int tls_dev_event(struct notifier_block *this, unsigned long event,
                         void *ptr)
{
        struct net_device *dev = netdev_notifier_info_to_dev(ptr);

        if (!(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
                return NOTIFY_DONE;

        switch (event) {
        case NETDEV_REGISTER:
        case NETDEV_FEAT_CHANGE:
                if ((dev->features & NETIF_F_HW_TLS_RX) &&
                    !dev->tlsdev_ops->tls_dev_resync_rx)
                        return NOTIFY_BAD;

                if (dev->tlsdev_ops &&
                    dev->tlsdev_ops->tls_dev_add &&
                    dev->tlsdev_ops->tls_dev_del)
                        return NOTIFY_DONE;
                else
                        return NOTIFY_BAD;
        case NETDEV_DOWN:
                return tls_device_down(dev);
        }
        return NOTIFY_DONE;
}

static struct notifier_block tls_dev_notifier = {
        .notifier_call  = tls_dev_event,
};

void __init tls_device_init(void)
{
        register_netdevice_notifier(&tls_dev_notifier);
}

void __exit tls_device_cleanup(void)
{
        unregister_netdevice_notifier(&tls_dev_notifier);
        flush_work(&tls_device_gc_work);
}