Merge tag 'dlm-4.14' of git://git.kernel.org/pub/scm/linux/kernel/git/teigland/linux-dlm
[sfrench/cifs-2.6.git] / crypto / algif_skcipher.c
1 /*
2  * algif_skcipher: User-space interface for skcipher algorithms
3  *
4  * This file provides the user-space API for symmetric key ciphers.
5  *
6  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
7  *
8  * This program is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by the Free
10  * Software Foundation; either version 2 of the License, or (at your option)
11  * any later version.
12  *
13  */
14
15 #include <crypto/scatterwalk.h>
16 #include <crypto/skcipher.h>
17 #include <crypto/if_alg.h>
18 #include <linux/init.h>
19 #include <linux/list.h>
20 #include <linux/kernel.h>
21 #include <linux/sched/signal.h>
22 #include <linux/mm.h>
23 #include <linux/module.h>
24 #include <linux/net.h>
25 #include <net/sock.h>
26
27 struct skcipher_sg_list {
28         struct list_head list;
29
30         int cur;
31
32         struct scatterlist sg[0];
33 };
34
35 struct skcipher_tfm {
36         struct crypto_skcipher *skcipher;
37         bool has_key;
38 };
39
40 struct skcipher_ctx {
41         struct list_head tsgl;
42         struct af_alg_sgl rsgl;
43
44         void *iv;
45
46         struct af_alg_completion completion;
47
48         atomic_t inflight;
49         size_t used;
50
51         unsigned int len;
52         bool more;
53         bool merge;
54         bool enc;
55
56         struct skcipher_request req;
57 };
58
59 struct skcipher_async_rsgl {
60         struct af_alg_sgl sgl;
61         struct list_head list;
62 };
63
64 struct skcipher_async_req {
65         struct kiocb *iocb;
66         struct skcipher_async_rsgl first_sgl;
67         struct list_head list;
68         struct scatterlist *tsg;
69         atomic_t *inflight;
70         struct skcipher_request req;
71 };
72
73 #define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
74                       sizeof(struct scatterlist) - 1)
75
76 static void skcipher_free_async_sgls(struct skcipher_async_req *sreq)
77 {
78         struct skcipher_async_rsgl *rsgl, *tmp;
79         struct scatterlist *sgl;
80         struct scatterlist *sg;
81         int i, n;
82
83         list_for_each_entry_safe(rsgl, tmp, &sreq->list, list) {
84                 af_alg_free_sg(&rsgl->sgl);
85                 if (rsgl != &sreq->first_sgl)
86                         kfree(rsgl);
87         }
88         sgl = sreq->tsg;
89         n = sg_nents(sgl);
90         for_each_sg(sgl, sg, n, i) {
91                 struct page *page = sg_page(sg);
92
93                 /* some SGs may not have a page mapped */
94                 if (page && page_ref_count(page))
95                         put_page(page);
96         }
97
98         kfree(sreq->tsg);
99 }
100
101 static void skcipher_async_cb(struct crypto_async_request *req, int err)
102 {
103         struct skcipher_async_req *sreq = req->data;
104         struct kiocb *iocb = sreq->iocb;
105
106         atomic_dec(sreq->inflight);
107         skcipher_free_async_sgls(sreq);
108         kzfree(sreq);
109         iocb->ki_complete(iocb, err, err);
110 }
111
112 static inline int skcipher_sndbuf(struct sock *sk)
113 {
114         struct alg_sock *ask = alg_sk(sk);
115         struct skcipher_ctx *ctx = ask->private;
116
117         return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
118                           ctx->used, 0);
119 }
120
121 static inline bool skcipher_writable(struct sock *sk)
122 {
123         return PAGE_SIZE <= skcipher_sndbuf(sk);
124 }
125
126 static int skcipher_alloc_sgl(struct sock *sk)
127 {
128         struct alg_sock *ask = alg_sk(sk);
129         struct skcipher_ctx *ctx = ask->private;
130         struct skcipher_sg_list *sgl;
131         struct scatterlist *sg = NULL;
132
133         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
134         if (!list_empty(&ctx->tsgl))
135                 sg = sgl->sg;
136
137         if (!sg || sgl->cur >= MAX_SGL_ENTS) {
138                 sgl = sock_kmalloc(sk, sizeof(*sgl) +
139                                        sizeof(sgl->sg[0]) * (MAX_SGL_ENTS + 1),
140                                    GFP_KERNEL);
141                 if (!sgl)
142                         return -ENOMEM;
143
144                 sg_init_table(sgl->sg, MAX_SGL_ENTS + 1);
145                 sgl->cur = 0;
146
147                 if (sg)
148                         sg_chain(sg, MAX_SGL_ENTS + 1, sgl->sg);
149
150                 list_add_tail(&sgl->list, &ctx->tsgl);
151         }
152
153         return 0;
154 }
155
156 static void skcipher_pull_sgl(struct sock *sk, size_t used, int put)
157 {
158         struct alg_sock *ask = alg_sk(sk);
159         struct skcipher_ctx *ctx = ask->private;
160         struct skcipher_sg_list *sgl;
161         struct scatterlist *sg;
162         int i;
163
164         while (!list_empty(&ctx->tsgl)) {
165                 sgl = list_first_entry(&ctx->tsgl, struct skcipher_sg_list,
166                                        list);
167                 sg = sgl->sg;
168
169                 for (i = 0; i < sgl->cur; i++) {
170                         size_t plen = min_t(size_t, used, sg[i].length);
171
172                         if (!sg_page(sg + i))
173                                 continue;
174
175                         sg[i].length -= plen;
176                         sg[i].offset += plen;
177
178                         used -= plen;
179                         ctx->used -= plen;
180
181                         if (sg[i].length)
182                                 return;
183                         if (put)
184                                 put_page(sg_page(sg + i));
185                         sg_assign_page(sg + i, NULL);
186                 }
187
188                 list_del(&sgl->list);
189                 sock_kfree_s(sk, sgl,
190                              sizeof(*sgl) + sizeof(sgl->sg[0]) *
191                                             (MAX_SGL_ENTS + 1));
192         }
193
194         if (!ctx->used)
195                 ctx->merge = 0;
196 }
197
198 static void skcipher_free_sgl(struct sock *sk)
199 {
200         struct alg_sock *ask = alg_sk(sk);
201         struct skcipher_ctx *ctx = ask->private;
202
203         skcipher_pull_sgl(sk, ctx->used, 1);
204 }
205
206 static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
207 {
208         DEFINE_WAIT_FUNC(wait, woken_wake_function);
209         int err = -ERESTARTSYS;
210         long timeout;
211
212         if (flags & MSG_DONTWAIT)
213                 return -EAGAIN;
214
215         sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
216
217         add_wait_queue(sk_sleep(sk), &wait);
218         for (;;) {
219                 if (signal_pending(current))
220                         break;
221                 timeout = MAX_SCHEDULE_TIMEOUT;
222                 if (sk_wait_event(sk, &timeout, skcipher_writable(sk), &wait)) {
223                         err = 0;
224                         break;
225                 }
226         }
227         remove_wait_queue(sk_sleep(sk), &wait);
228
229         return err;
230 }
231
232 static void skcipher_wmem_wakeup(struct sock *sk)
233 {
234         struct socket_wq *wq;
235
236         if (!skcipher_writable(sk))
237                 return;
238
239         rcu_read_lock();
240         wq = rcu_dereference(sk->sk_wq);
241         if (skwq_has_sleeper(wq))
242                 wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
243                                                            POLLRDNORM |
244                                                            POLLRDBAND);
245         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
246         rcu_read_unlock();
247 }
248
249 static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
250 {
251         DEFINE_WAIT_FUNC(wait, woken_wake_function);
252         struct alg_sock *ask = alg_sk(sk);
253         struct skcipher_ctx *ctx = ask->private;
254         long timeout;
255         int err = -ERESTARTSYS;
256
257         if (flags & MSG_DONTWAIT) {
258                 return -EAGAIN;
259         }
260
261         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
262
263         add_wait_queue(sk_sleep(sk), &wait);
264         for (;;) {
265                 if (signal_pending(current))
266                         break;
267                 timeout = MAX_SCHEDULE_TIMEOUT;
268                 if (sk_wait_event(sk, &timeout, ctx->used, &wait)) {
269                         err = 0;
270                         break;
271                 }
272         }
273         remove_wait_queue(sk_sleep(sk), &wait);
274
275         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
276
277         return err;
278 }
279
280 static void skcipher_data_wakeup(struct sock *sk)
281 {
282         struct alg_sock *ask = alg_sk(sk);
283         struct skcipher_ctx *ctx = ask->private;
284         struct socket_wq *wq;
285
286         if (!ctx->used)
287                 return;
288
289         rcu_read_lock();
290         wq = rcu_dereference(sk->sk_wq);
291         if (skwq_has_sleeper(wq))
292                 wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
293                                                            POLLRDNORM |
294                                                            POLLRDBAND);
295         sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
296         rcu_read_unlock();
297 }
298
299 static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
300                             size_t size)
301 {
302         struct sock *sk = sock->sk;
303         struct alg_sock *ask = alg_sk(sk);
304         struct sock *psk = ask->parent;
305         struct alg_sock *pask = alg_sk(psk);
306         struct skcipher_ctx *ctx = ask->private;
307         struct skcipher_tfm *skc = pask->private;
308         struct crypto_skcipher *tfm = skc->skcipher;
309         unsigned ivsize = crypto_skcipher_ivsize(tfm);
310         struct skcipher_sg_list *sgl;
311         struct af_alg_control con = {};
312         long copied = 0;
313         bool enc = 0;
314         bool init = 0;
315         int err;
316         int i;
317
318         if (msg->msg_controllen) {
319                 err = af_alg_cmsg_send(msg, &con);
320                 if (err)
321                         return err;
322
323                 init = 1;
324                 switch (con.op) {
325                 case ALG_OP_ENCRYPT:
326                         enc = 1;
327                         break;
328                 case ALG_OP_DECRYPT:
329                         enc = 0;
330                         break;
331                 default:
332                         return -EINVAL;
333                 }
334
335                 if (con.iv && con.iv->ivlen != ivsize)
336                         return -EINVAL;
337         }
338
339         err = -EINVAL;
340
341         lock_sock(sk);
342         if (!ctx->more && ctx->used)
343                 goto unlock;
344
345         if (init) {
346                 ctx->enc = enc;
347                 if (con.iv)
348                         memcpy(ctx->iv, con.iv->iv, ivsize);
349         }
350
351         while (size) {
352                 struct scatterlist *sg;
353                 unsigned long len = size;
354                 size_t plen;
355
356                 if (ctx->merge) {
357                         sgl = list_entry(ctx->tsgl.prev,
358                                          struct skcipher_sg_list, list);
359                         sg = sgl->sg + sgl->cur - 1;
360                         len = min_t(unsigned long, len,
361                                     PAGE_SIZE - sg->offset - sg->length);
362
363                         err = memcpy_from_msg(page_address(sg_page(sg)) +
364                                               sg->offset + sg->length,
365                                               msg, len);
366                         if (err)
367                                 goto unlock;
368
369                         sg->length += len;
370                         ctx->merge = (sg->offset + sg->length) &
371                                      (PAGE_SIZE - 1);
372
373                         ctx->used += len;
374                         copied += len;
375                         size -= len;
376                         continue;
377                 }
378
379                 if (!skcipher_writable(sk)) {
380                         err = skcipher_wait_for_wmem(sk, msg->msg_flags);
381                         if (err)
382                                 goto unlock;
383                 }
384
385                 len = min_t(unsigned long, len, skcipher_sndbuf(sk));
386
387                 err = skcipher_alloc_sgl(sk);
388                 if (err)
389                         goto unlock;
390
391                 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
392                 sg = sgl->sg;
393                 if (sgl->cur)
394                         sg_unmark_end(sg + sgl->cur - 1);
395                 do {
396                         i = sgl->cur;
397                         plen = min_t(size_t, len, PAGE_SIZE);
398
399                         sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
400                         err = -ENOMEM;
401                         if (!sg_page(sg + i))
402                                 goto unlock;
403
404                         err = memcpy_from_msg(page_address(sg_page(sg + i)),
405                                               msg, plen);
406                         if (err) {
407                                 __free_page(sg_page(sg + i));
408                                 sg_assign_page(sg + i, NULL);
409                                 goto unlock;
410                         }
411
412                         sg[i].length = plen;
413                         len -= plen;
414                         ctx->used += plen;
415                         copied += plen;
416                         size -= plen;
417                         sgl->cur++;
418                 } while (len && sgl->cur < MAX_SGL_ENTS);
419
420                 if (!size)
421                         sg_mark_end(sg + sgl->cur - 1);
422
423                 ctx->merge = plen & (PAGE_SIZE - 1);
424         }
425
426         err = 0;
427
428         ctx->more = msg->msg_flags & MSG_MORE;
429
430 unlock:
431         skcipher_data_wakeup(sk);
432         release_sock(sk);
433
434         return copied ?: err;
435 }
436
437 static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
438                                  int offset, size_t size, int flags)
439 {
440         struct sock *sk = sock->sk;
441         struct alg_sock *ask = alg_sk(sk);
442         struct skcipher_ctx *ctx = ask->private;
443         struct skcipher_sg_list *sgl;
444         int err = -EINVAL;
445
446         if (flags & MSG_SENDPAGE_NOTLAST)
447                 flags |= MSG_MORE;
448
449         lock_sock(sk);
450         if (!ctx->more && ctx->used)
451                 goto unlock;
452
453         if (!size)
454                 goto done;
455
456         if (!skcipher_writable(sk)) {
457                 err = skcipher_wait_for_wmem(sk, flags);
458                 if (err)
459                         goto unlock;
460         }
461
462         err = skcipher_alloc_sgl(sk);
463         if (err)
464                 goto unlock;
465
466         ctx->merge = 0;
467         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
468
469         if (sgl->cur)
470                 sg_unmark_end(sgl->sg + sgl->cur - 1);
471
472         sg_mark_end(sgl->sg + sgl->cur);
473         get_page(page);
474         sg_set_page(sgl->sg + sgl->cur, page, size, offset);
475         sgl->cur++;
476         ctx->used += size;
477
478 done:
479         ctx->more = flags & MSG_MORE;
480
481 unlock:
482         skcipher_data_wakeup(sk);
483         release_sock(sk);
484
485         return err ?: size;
486 }
487
488 static int skcipher_all_sg_nents(struct skcipher_ctx *ctx)
489 {
490         struct skcipher_sg_list *sgl;
491         struct scatterlist *sg;
492         int nents = 0;
493
494         list_for_each_entry(sgl, &ctx->tsgl, list) {
495                 sg = sgl->sg;
496
497                 while (!sg->length)
498                         sg++;
499
500                 nents += sg_nents(sg);
501         }
502         return nents;
503 }
504
505 static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
506                                   int flags)
507 {
508         struct sock *sk = sock->sk;
509         struct alg_sock *ask = alg_sk(sk);
510         struct sock *psk = ask->parent;
511         struct alg_sock *pask = alg_sk(psk);
512         struct skcipher_ctx *ctx = ask->private;
513         struct skcipher_tfm *skc = pask->private;
514         struct crypto_skcipher *tfm = skc->skcipher;
515         struct skcipher_sg_list *sgl;
516         struct scatterlist *sg;
517         struct skcipher_async_req *sreq;
518         struct skcipher_request *req;
519         struct skcipher_async_rsgl *last_rsgl = NULL;
520         unsigned int txbufs = 0, len = 0, tx_nents;
521         unsigned int reqsize = crypto_skcipher_reqsize(tfm);
522         unsigned int ivsize = crypto_skcipher_ivsize(tfm);
523         int err = -ENOMEM;
524         bool mark = false;
525         char *iv;
526
527         sreq = kzalloc(sizeof(*sreq) + reqsize + ivsize, GFP_KERNEL);
528         if (unlikely(!sreq))
529                 goto out;
530
531         req = &sreq->req;
532         iv = (char *)(req + 1) + reqsize;
533         sreq->iocb = msg->msg_iocb;
534         INIT_LIST_HEAD(&sreq->list);
535         sreq->inflight = &ctx->inflight;
536
537         lock_sock(sk);
538         tx_nents = skcipher_all_sg_nents(ctx);
539         sreq->tsg = kcalloc(tx_nents, sizeof(*sg), GFP_KERNEL);
540         if (unlikely(!sreq->tsg))
541                 goto unlock;
542         sg_init_table(sreq->tsg, tx_nents);
543         memcpy(iv, ctx->iv, ivsize);
544         skcipher_request_set_tfm(req, tfm);
545         skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP,
546                                       skcipher_async_cb, sreq);
547
548         while (iov_iter_count(&msg->msg_iter)) {
549                 struct skcipher_async_rsgl *rsgl;
550                 int used;
551
552                 if (!ctx->used) {
553                         err = skcipher_wait_for_data(sk, flags);
554                         if (err)
555                                 goto free;
556                 }
557                 sgl = list_first_entry(&ctx->tsgl,
558                                        struct skcipher_sg_list, list);
559                 sg = sgl->sg;
560
561                 while (!sg->length)
562                         sg++;
563
564                 used = min_t(unsigned long, ctx->used,
565                              iov_iter_count(&msg->msg_iter));
566                 used = min_t(unsigned long, used, sg->length);
567
568                 if (txbufs == tx_nents) {
569                         struct scatterlist *tmp;
570                         int x;
571                         /* Ran out of tx slots in async request
572                          * need to expand */
573                         tmp = kcalloc(tx_nents * 2, sizeof(*tmp),
574                                       GFP_KERNEL);
575                         if (!tmp) {
576                                 err = -ENOMEM;
577                                 goto free;
578                         }
579
580                         sg_init_table(tmp, tx_nents * 2);
581                         for (x = 0; x < tx_nents; x++)
582                                 sg_set_page(&tmp[x], sg_page(&sreq->tsg[x]),
583                                             sreq->tsg[x].length,
584                                             sreq->tsg[x].offset);
585                         kfree(sreq->tsg);
586                         sreq->tsg = tmp;
587                         tx_nents *= 2;
588                         mark = true;
589                 }
590                 /* Need to take over the tx sgl from ctx
591                  * to the asynch req - these sgls will be freed later */
592                 sg_set_page(sreq->tsg + txbufs++, sg_page(sg), sg->length,
593                             sg->offset);
594
595                 if (list_empty(&sreq->list)) {
596                         rsgl = &sreq->first_sgl;
597                         list_add_tail(&rsgl->list, &sreq->list);
598                 } else {
599                         rsgl = kmalloc(sizeof(*rsgl), GFP_KERNEL);
600                         if (!rsgl) {
601                                 err = -ENOMEM;
602                                 goto free;
603                         }
604                         list_add_tail(&rsgl->list, &sreq->list);
605                 }
606
607                 used = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, used);
608                 err = used;
609                 if (used < 0)
610                         goto free;
611                 if (last_rsgl)
612                         af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);
613
614                 last_rsgl = rsgl;
615                 len += used;
616                 skcipher_pull_sgl(sk, used, 0);
617                 iov_iter_advance(&msg->msg_iter, used);
618         }
619
620         if (mark)
621                 sg_mark_end(sreq->tsg + txbufs - 1);
622
623         skcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
624                                    len, iv);
625         err = ctx->enc ? crypto_skcipher_encrypt(req) :
626                          crypto_skcipher_decrypt(req);
627         if (err == -EINPROGRESS) {
628                 atomic_inc(&ctx->inflight);
629                 err = -EIOCBQUEUED;
630                 sreq = NULL;
631                 goto unlock;
632         }
633 free:
634         skcipher_free_async_sgls(sreq);
635 unlock:
636         skcipher_wmem_wakeup(sk);
637         release_sock(sk);
638         kzfree(sreq);
639 out:
640         return err;
641 }
642
643 static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
644                                  int flags)
645 {
646         struct sock *sk = sock->sk;
647         struct alg_sock *ask = alg_sk(sk);
648         struct sock *psk = ask->parent;
649         struct alg_sock *pask = alg_sk(psk);
650         struct skcipher_ctx *ctx = ask->private;
651         struct skcipher_tfm *skc = pask->private;
652         struct crypto_skcipher *tfm = skc->skcipher;
653         unsigned bs = crypto_skcipher_blocksize(tfm);
654         struct skcipher_sg_list *sgl;
655         struct scatterlist *sg;
656         int err = -EAGAIN;
657         int used;
658         long copied = 0;
659
660         lock_sock(sk);
661         while (msg_data_left(msg)) {
662                 if (!ctx->used) {
663                         err = skcipher_wait_for_data(sk, flags);
664                         if (err)
665                                 goto unlock;
666                 }
667
668                 used = min_t(unsigned long, ctx->used, msg_data_left(msg));
669
670                 used = af_alg_make_sg(&ctx->rsgl, &msg->msg_iter, used);
671                 err = used;
672                 if (err < 0)
673                         goto unlock;
674
675                 if (ctx->more || used < ctx->used)
676                         used -= used % bs;
677
678                 err = -EINVAL;
679                 if (!used)
680                         goto free;
681
682                 sgl = list_first_entry(&ctx->tsgl,
683                                        struct skcipher_sg_list, list);
684                 sg = sgl->sg;
685
686                 while (!sg->length)
687                         sg++;
688
689                 skcipher_request_set_crypt(&ctx->req, sg, ctx->rsgl.sg, used,
690                                            ctx->iv);
691
692                 err = af_alg_wait_for_completion(
693                                 ctx->enc ?
694                                         crypto_skcipher_encrypt(&ctx->req) :
695                                         crypto_skcipher_decrypt(&ctx->req),
696                                 &ctx->completion);
697
698 free:
699                 af_alg_free_sg(&ctx->rsgl);
700
701                 if (err)
702                         goto unlock;
703
704                 copied += used;
705                 skcipher_pull_sgl(sk, used, 1);
706                 iov_iter_advance(&msg->msg_iter, used);
707         }
708
709         err = 0;
710
711 unlock:
712         skcipher_wmem_wakeup(sk);
713         release_sock(sk);
714
715         return copied ?: err;
716 }
717
718 static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
719                             size_t ignored, int flags)
720 {
721         return (msg->msg_iocb && !is_sync_kiocb(msg->msg_iocb)) ?
722                 skcipher_recvmsg_async(sock, msg, flags) :
723                 skcipher_recvmsg_sync(sock, msg, flags);
724 }
725
726 static unsigned int skcipher_poll(struct file *file, struct socket *sock,
727                                   poll_table *wait)
728 {
729         struct sock *sk = sock->sk;
730         struct alg_sock *ask = alg_sk(sk);
731         struct skcipher_ctx *ctx = ask->private;
732         unsigned int mask;
733
734         sock_poll_wait(file, sk_sleep(sk), wait);
735         mask = 0;
736
737         if (ctx->used)
738                 mask |= POLLIN | POLLRDNORM;
739
740         if (skcipher_writable(sk))
741                 mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
742
743         return mask;
744 }
745
746 static struct proto_ops algif_skcipher_ops = {
747         .family         =       PF_ALG,
748
749         .connect        =       sock_no_connect,
750         .socketpair     =       sock_no_socketpair,
751         .getname        =       sock_no_getname,
752         .ioctl          =       sock_no_ioctl,
753         .listen         =       sock_no_listen,
754         .shutdown       =       sock_no_shutdown,
755         .getsockopt     =       sock_no_getsockopt,
756         .mmap           =       sock_no_mmap,
757         .bind           =       sock_no_bind,
758         .accept         =       sock_no_accept,
759         .setsockopt     =       sock_no_setsockopt,
760
761         .release        =       af_alg_release,
762         .sendmsg        =       skcipher_sendmsg,
763         .sendpage       =       skcipher_sendpage,
764         .recvmsg        =       skcipher_recvmsg,
765         .poll           =       skcipher_poll,
766 };
767
768 static int skcipher_check_key(struct socket *sock)
769 {
770         int err = 0;
771         struct sock *psk;
772         struct alg_sock *pask;
773         struct skcipher_tfm *tfm;
774         struct sock *sk = sock->sk;
775         struct alg_sock *ask = alg_sk(sk);
776
777         lock_sock(sk);
778         if (ask->refcnt)
779                 goto unlock_child;
780
781         psk = ask->parent;
782         pask = alg_sk(ask->parent);
783         tfm = pask->private;
784
785         err = -ENOKEY;
786         lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
787         if (!tfm->has_key)
788                 goto unlock;
789
790         if (!pask->refcnt++)
791                 sock_hold(psk);
792
793         ask->refcnt = 1;
794         sock_put(psk);
795
796         err = 0;
797
798 unlock:
799         release_sock(psk);
800 unlock_child:
801         release_sock(sk);
802
803         return err;
804 }
805
806 static int skcipher_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
807                                   size_t size)
808 {
809         int err;
810
811         err = skcipher_check_key(sock);
812         if (err)
813                 return err;
814
815         return skcipher_sendmsg(sock, msg, size);
816 }
817
818 static ssize_t skcipher_sendpage_nokey(struct socket *sock, struct page *page,
819                                        int offset, size_t size, int flags)
820 {
821         int err;
822
823         err = skcipher_check_key(sock);
824         if (err)
825                 return err;
826
827         return skcipher_sendpage(sock, page, offset, size, flags);
828 }
829
830 static int skcipher_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
831                                   size_t ignored, int flags)
832 {
833         int err;
834
835         err = skcipher_check_key(sock);
836         if (err)
837                 return err;
838
839         return skcipher_recvmsg(sock, msg, ignored, flags);
840 }
841
842 static struct proto_ops algif_skcipher_ops_nokey = {
843         .family         =       PF_ALG,
844
845         .connect        =       sock_no_connect,
846         .socketpair     =       sock_no_socketpair,
847         .getname        =       sock_no_getname,
848         .ioctl          =       sock_no_ioctl,
849         .listen         =       sock_no_listen,
850         .shutdown       =       sock_no_shutdown,
851         .getsockopt     =       sock_no_getsockopt,
852         .mmap           =       sock_no_mmap,
853         .bind           =       sock_no_bind,
854         .accept         =       sock_no_accept,
855         .setsockopt     =       sock_no_setsockopt,
856
857         .release        =       af_alg_release,
858         .sendmsg        =       skcipher_sendmsg_nokey,
859         .sendpage       =       skcipher_sendpage_nokey,
860         .recvmsg        =       skcipher_recvmsg_nokey,
861         .poll           =       skcipher_poll,
862 };
863
864 static void *skcipher_bind(const char *name, u32 type, u32 mask)
865 {
866         struct skcipher_tfm *tfm;
867         struct crypto_skcipher *skcipher;
868
869         tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
870         if (!tfm)
871                 return ERR_PTR(-ENOMEM);
872
873         skcipher = crypto_alloc_skcipher(name, type, mask);
874         if (IS_ERR(skcipher)) {
875                 kfree(tfm);
876                 return ERR_CAST(skcipher);
877         }
878
879         tfm->skcipher = skcipher;
880
881         return tfm;
882 }
883
884 static void skcipher_release(void *private)
885 {
886         struct skcipher_tfm *tfm = private;
887
888         crypto_free_skcipher(tfm->skcipher);
889         kfree(tfm);
890 }
891
892 static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
893 {
894         struct skcipher_tfm *tfm = private;
895         int err;
896
897         err = crypto_skcipher_setkey(tfm->skcipher, key, keylen);
898         tfm->has_key = !err;
899
900         return err;
901 }
902
903 static void skcipher_wait(struct sock *sk)
904 {
905         struct alg_sock *ask = alg_sk(sk);
906         struct skcipher_ctx *ctx = ask->private;
907         int ctr = 0;
908
909         while (atomic_read(&ctx->inflight) && ctr++ < 100)
910                 msleep(100);
911 }
912
913 static void skcipher_sock_destruct(struct sock *sk)
914 {
915         struct alg_sock *ask = alg_sk(sk);
916         struct skcipher_ctx *ctx = ask->private;
917         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
918
919         if (atomic_read(&ctx->inflight))
920                 skcipher_wait(sk);
921
922         skcipher_free_sgl(sk);
923         sock_kzfree_s(sk, ctx->iv, crypto_skcipher_ivsize(tfm));
924         sock_kfree_s(sk, ctx, ctx->len);
925         af_alg_release_parent(sk);
926 }
927
928 static int skcipher_accept_parent_nokey(void *private, struct sock *sk)
929 {
930         struct skcipher_ctx *ctx;
931         struct alg_sock *ask = alg_sk(sk);
932         struct skcipher_tfm *tfm = private;
933         struct crypto_skcipher *skcipher = tfm->skcipher;
934         unsigned int len = sizeof(*ctx) + crypto_skcipher_reqsize(skcipher);
935
936         ctx = sock_kmalloc(sk, len, GFP_KERNEL);
937         if (!ctx)
938                 return -ENOMEM;
939
940         ctx->iv = sock_kmalloc(sk, crypto_skcipher_ivsize(skcipher),
941                                GFP_KERNEL);
942         if (!ctx->iv) {
943                 sock_kfree_s(sk, ctx, len);
944                 return -ENOMEM;
945         }
946
947         memset(ctx->iv, 0, crypto_skcipher_ivsize(skcipher));
948
949         INIT_LIST_HEAD(&ctx->tsgl);
950         ctx->len = len;
951         ctx->used = 0;
952         ctx->more = 0;
953         ctx->merge = 0;
954         ctx->enc = 0;
955         atomic_set(&ctx->inflight, 0);
956         af_alg_init_completion(&ctx->completion);
957
958         ask->private = ctx;
959
960         skcipher_request_set_tfm(&ctx->req, skcipher);
961         skcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_SLEEP |
962                                                  CRYPTO_TFM_REQ_MAY_BACKLOG,
963                                       af_alg_complete, &ctx->completion);
964
965         sk->sk_destruct = skcipher_sock_destruct;
966
967         return 0;
968 }
969
970 static int skcipher_accept_parent(void *private, struct sock *sk)
971 {
972         struct skcipher_tfm *tfm = private;
973
974         if (!tfm->has_key && crypto_skcipher_has_setkey(tfm->skcipher))
975                 return -ENOKEY;
976
977         return skcipher_accept_parent_nokey(private, sk);
978 }
979
980 static const struct af_alg_type algif_type_skcipher = {
981         .bind           =       skcipher_bind,
982         .release        =       skcipher_release,
983         .setkey         =       skcipher_setkey,
984         .accept         =       skcipher_accept_parent,
985         .accept_nokey   =       skcipher_accept_parent_nokey,
986         .ops            =       &algif_skcipher_ops,
987         .ops_nokey      =       &algif_skcipher_ops_nokey,
988         .name           =       "skcipher",
989         .owner          =       THIS_MODULE
990 };
991
992 static int __init algif_skcipher_init(void)
993 {
994         return af_alg_register_type(&algif_type_skcipher);
995 }
996
997 static void __exit algif_skcipher_exit(void)
998 {
999         int err = af_alg_unregister_type(&algif_type_skcipher);
1000         BUG_ON(err);
1001 }
1002
1003 module_init(algif_skcipher_init);
1004 module_exit(algif_skcipher_exit);
1005 MODULE_LICENSE("GPL");