Merge branch 'nvme-4.19' of git://git.infradead.org/nvme into for-linus
[sfrench/cifs-2.6.git] / kernel / bpf / sockmap.c
1 /* Copyright (c) 2017 Covalent IO, Inc. http://covalent.io
2  *
3  * This program is free software; you can redistribute it and/or
4  * modify it under the terms of version 2 of the GNU General Public
5  * License as published by the Free Software Foundation.
6  *
7  * This program is distributed in the hope that it will be useful, but
8  * WITHOUT ANY WARRANTY; without even the implied warranty of
9  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10  * General Public License for more details.
11  */
12
13 /* A BPF sock_map is used to store sock objects. This is primarly used
14  * for doing socket redirect with BPF helper routines.
15  *
16  * A sock map may have BPF programs attached to it, currently a program
17  * used to parse packets and a program to provide a verdict and redirect
18  * decision on the packet are supported. Any programs attached to a sock
19  * map are inherited by sock objects when they are added to the map. If
20  * no BPF programs are attached the sock object may only be used for sock
21  * redirect.
22  *
23  * A sock object may be in multiple maps, but can only inherit a single
24  * parse or verdict program. If adding a sock object to a map would result
25  * in having multiple parsing programs the update will return an EBUSY error.
26  *
27  * For reference this program is similar to devmap used in XDP context
28  * reviewing these together may be useful. For an example please review
29  * ./samples/bpf/sockmap/.
30  */
31 #include <linux/bpf.h>
32 #include <net/sock.h>
33 #include <linux/filter.h>
34 #include <linux/errno.h>
35 #include <linux/file.h>
36 #include <linux/kernel.h>
37 #include <linux/net.h>
38 #include <linux/skbuff.h>
39 #include <linux/workqueue.h>
40 #include <linux/list.h>
41 #include <linux/mm.h>
42 #include <net/strparser.h>
43 #include <net/tcp.h>
44 #include <linux/ptr_ring.h>
45 #include <net/inet_common.h>
46 #include <linux/sched/signal.h>
47
48 #define SOCK_CREATE_FLAG_MASK \
49         (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
50
51 struct bpf_sock_progs {
52         struct bpf_prog *bpf_tx_msg;
53         struct bpf_prog *bpf_parse;
54         struct bpf_prog *bpf_verdict;
55 };
56
57 struct bpf_stab {
58         struct bpf_map map;
59         struct sock **sock_map;
60         struct bpf_sock_progs progs;
61         raw_spinlock_t lock;
62 };
63
64 struct bucket {
65         struct hlist_head head;
66         raw_spinlock_t lock;
67 };
68
69 struct bpf_htab {
70         struct bpf_map map;
71         struct bucket *buckets;
72         atomic_t count;
73         u32 n_buckets;
74         u32 elem_size;
75         struct bpf_sock_progs progs;
76         struct rcu_head rcu;
77 };
78
79 struct htab_elem {
80         struct rcu_head rcu;
81         struct hlist_node hash_node;
82         u32 hash;
83         struct sock *sk;
84         char key[0];
85 };
86
87 enum smap_psock_state {
88         SMAP_TX_RUNNING,
89 };
90
91 struct smap_psock_map_entry {
92         struct list_head list;
93         struct bpf_map *map;
94         struct sock **entry;
95         struct htab_elem __rcu *hash_link;
96 };
97
98 struct smap_psock {
99         struct rcu_head rcu;
100         refcount_t refcnt;
101
102         /* datapath variables */
103         struct sk_buff_head rxqueue;
104         bool strp_enabled;
105
106         /* datapath error path cache across tx work invocations */
107         int save_rem;
108         int save_off;
109         struct sk_buff *save_skb;
110
111         /* datapath variables for tx_msg ULP */
112         struct sock *sk_redir;
113         int apply_bytes;
114         int cork_bytes;
115         int sg_size;
116         int eval;
117         struct sk_msg_buff *cork;
118         struct list_head ingress;
119
120         struct strparser strp;
121         struct bpf_prog *bpf_tx_msg;
122         struct bpf_prog *bpf_parse;
123         struct bpf_prog *bpf_verdict;
124         struct list_head maps;
125         spinlock_t maps_lock;
126
127         /* Back reference used when sock callback trigger sockmap operations */
128         struct sock *sock;
129         unsigned long state;
130
131         struct work_struct tx_work;
132         struct work_struct gc_work;
133
134         struct proto *sk_proto;
135         void (*save_close)(struct sock *sk, long timeout);
136         void (*save_data_ready)(struct sock *sk);
137         void (*save_write_space)(struct sock *sk);
138 };
139
140 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
141 static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
142                            int nonblock, int flags, int *addr_len);
143 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
144 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
145                             int offset, size_t size, int flags);
146 static void bpf_tcp_close(struct sock *sk, long timeout);
147
148 static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
149 {
150         return rcu_dereference_sk_user_data(sk);
151 }
152
153 static bool bpf_tcp_stream_read(const struct sock *sk)
154 {
155         struct smap_psock *psock;
156         bool empty = true;
157
158         rcu_read_lock();
159         psock = smap_psock_sk(sk);
160         if (unlikely(!psock))
161                 goto out;
162         empty = list_empty(&psock->ingress);
163 out:
164         rcu_read_unlock();
165         return !empty;
166 }
167
168 enum {
169         SOCKMAP_IPV4,
170         SOCKMAP_IPV6,
171         SOCKMAP_NUM_PROTS,
172 };
173
174 enum {
175         SOCKMAP_BASE,
176         SOCKMAP_TX,
177         SOCKMAP_NUM_CONFIGS,
178 };
179
180 static struct proto *saved_tcpv6_prot __read_mostly;
181 static DEFINE_SPINLOCK(tcpv6_prot_lock);
182 static struct proto bpf_tcp_prots[SOCKMAP_NUM_PROTS][SOCKMAP_NUM_CONFIGS];
183 static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
184                          struct proto *base)
185 {
186         prot[SOCKMAP_BASE]                      = *base;
187         prot[SOCKMAP_BASE].close                = bpf_tcp_close;
188         prot[SOCKMAP_BASE].recvmsg              = bpf_tcp_recvmsg;
189         prot[SOCKMAP_BASE].stream_memory_read   = bpf_tcp_stream_read;
190
191         prot[SOCKMAP_TX]                        = prot[SOCKMAP_BASE];
192         prot[SOCKMAP_TX].sendmsg                = bpf_tcp_sendmsg;
193         prot[SOCKMAP_TX].sendpage               = bpf_tcp_sendpage;
194 }
195
196 static void update_sk_prot(struct sock *sk, struct smap_psock *psock)
197 {
198         int family = sk->sk_family == AF_INET6 ? SOCKMAP_IPV6 : SOCKMAP_IPV4;
199         int conf = psock->bpf_tx_msg ? SOCKMAP_TX : SOCKMAP_BASE;
200
201         sk->sk_prot = &bpf_tcp_prots[family][conf];
202 }
203
204 static int bpf_tcp_init(struct sock *sk)
205 {
206         struct smap_psock *psock;
207
208         rcu_read_lock();
209         psock = smap_psock_sk(sk);
210         if (unlikely(!psock)) {
211                 rcu_read_unlock();
212                 return -EINVAL;
213         }
214
215         if (unlikely(psock->sk_proto)) {
216                 rcu_read_unlock();
217                 return -EBUSY;
218         }
219
220         psock->save_close = sk->sk_prot->close;
221         psock->sk_proto = sk->sk_prot;
222
223         /* Build IPv6 sockmap whenever the address of tcpv6_prot changes */
224         if (sk->sk_family == AF_INET6 &&
225             unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
226                 spin_lock_bh(&tcpv6_prot_lock);
227                 if (likely(sk->sk_prot != saved_tcpv6_prot)) {
228                         build_protos(bpf_tcp_prots[SOCKMAP_IPV6], sk->sk_prot);
229                         smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
230                 }
231                 spin_unlock_bh(&tcpv6_prot_lock);
232         }
233         update_sk_prot(sk, psock);
234         rcu_read_unlock();
235         return 0;
236 }
237
238 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
239 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
240
241 static void bpf_tcp_release(struct sock *sk)
242 {
243         struct smap_psock *psock;
244
245         rcu_read_lock();
246         psock = smap_psock_sk(sk);
247         if (unlikely(!psock))
248                 goto out;
249
250         if (psock->cork) {
251                 free_start_sg(psock->sock, psock->cork);
252                 kfree(psock->cork);
253                 psock->cork = NULL;
254         }
255
256         if (psock->sk_proto) {
257                 sk->sk_prot = psock->sk_proto;
258                 psock->sk_proto = NULL;
259         }
260 out:
261         rcu_read_unlock();
262 }
263
264 static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
265                                          u32 hash, void *key, u32 key_size)
266 {
267         struct htab_elem *l;
268
269         hlist_for_each_entry_rcu(l, head, hash_node) {
270                 if (l->hash == hash && !memcmp(&l->key, key, key_size))
271                         return l;
272         }
273
274         return NULL;
275 }
276
277 static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
278 {
279         return &htab->buckets[hash & (htab->n_buckets - 1)];
280 }
281
282 static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
283 {
284         return &__select_bucket(htab, hash)->head;
285 }
286
287 static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
288 {
289         atomic_dec(&htab->count);
290         kfree_rcu(l, rcu);
291 }
292
293 static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
294                                                   struct smap_psock *psock)
295 {
296         struct smap_psock_map_entry *e;
297
298         spin_lock_bh(&psock->maps_lock);
299         e = list_first_entry_or_null(&psock->maps,
300                                      struct smap_psock_map_entry,
301                                      list);
302         if (e)
303                 list_del(&e->list);
304         spin_unlock_bh(&psock->maps_lock);
305         return e;
306 }
307
308 static void bpf_tcp_close(struct sock *sk, long timeout)
309 {
310         void (*close_fun)(struct sock *sk, long timeout);
311         struct smap_psock_map_entry *e;
312         struct sk_msg_buff *md, *mtmp;
313         struct smap_psock *psock;
314         struct sock *osk;
315
316         lock_sock(sk);
317         rcu_read_lock();
318         psock = smap_psock_sk(sk);
319         if (unlikely(!psock)) {
320                 rcu_read_unlock();
321                 release_sock(sk);
322                 return sk->sk_prot->close(sk, timeout);
323         }
324
325         /* The psock may be destroyed anytime after exiting the RCU critial
326          * section so by the time we use close_fun the psock may no longer
327          * be valid. However, bpf_tcp_close is called with the sock lock
328          * held so the close hook and sk are still valid.
329          */
330         close_fun = psock->save_close;
331
332         if (psock->cork) {
333                 free_start_sg(psock->sock, psock->cork);
334                 kfree(psock->cork);
335                 psock->cork = NULL;
336         }
337
338         list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
339                 list_del(&md->list);
340                 free_start_sg(psock->sock, md);
341                 kfree(md);
342         }
343
344         e = psock_map_pop(sk, psock);
345         while (e) {
346                 if (e->entry) {
347                         struct bpf_stab *stab = container_of(e->map, struct bpf_stab, map);
348
349                         raw_spin_lock_bh(&stab->lock);
350                         osk = *e->entry;
351                         if (osk == sk) {
352                                 *e->entry = NULL;
353                                 smap_release_sock(psock, sk);
354                         }
355                         raw_spin_unlock_bh(&stab->lock);
356                 } else {
357                         struct htab_elem *link = rcu_dereference(e->hash_link);
358                         struct bpf_htab *htab = container_of(e->map, struct bpf_htab, map);
359                         struct hlist_head *head;
360                         struct htab_elem *l;
361                         struct bucket *b;
362
363                         b = __select_bucket(htab, link->hash);
364                         head = &b->head;
365                         raw_spin_lock_bh(&b->lock);
366                         l = lookup_elem_raw(head,
367                                             link->hash, link->key,
368                                             htab->map.key_size);
369                         /* If another thread deleted this object skip deletion.
370                          * The refcnt on psock may or may not be zero.
371                          */
372                         if (l) {
373                                 hlist_del_rcu(&link->hash_node);
374                                 smap_release_sock(psock, link->sk);
375                                 free_htab_elem(htab, link);
376                         }
377                         raw_spin_unlock_bh(&b->lock);
378                 }
379                 kfree(e);
380                 e = psock_map_pop(sk, psock);
381         }
382         rcu_read_unlock();
383         release_sock(sk);
384         close_fun(sk, timeout);
385 }
386
387 enum __sk_action {
388         __SK_DROP = 0,
389         __SK_PASS,
390         __SK_REDIRECT,
391         __SK_NONE,
392 };
393
394 static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
395         .name           = "bpf_tcp",
396         .uid            = TCP_ULP_BPF,
397         .user_visible   = false,
398         .owner          = NULL,
399         .init           = bpf_tcp_init,
400         .release        = bpf_tcp_release,
401 };
402
403 static int memcopy_from_iter(struct sock *sk,
404                              struct sk_msg_buff *md,
405                              struct iov_iter *from, int bytes)
406 {
407         struct scatterlist *sg = md->sg_data;
408         int i = md->sg_curr, rc = -ENOSPC;
409
410         do {
411                 int copy;
412                 char *to;
413
414                 if (md->sg_copybreak >= sg[i].length) {
415                         md->sg_copybreak = 0;
416
417                         if (++i == MAX_SKB_FRAGS)
418                                 i = 0;
419
420                         if (i == md->sg_end)
421                                 break;
422                 }
423
424                 copy = sg[i].length - md->sg_copybreak;
425                 to = sg_virt(&sg[i]) + md->sg_copybreak;
426                 md->sg_copybreak += copy;
427
428                 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
429                         rc = copy_from_iter_nocache(to, copy, from);
430                 else
431                         rc = copy_from_iter(to, copy, from);
432
433                 if (rc != copy) {
434                         rc = -EFAULT;
435                         goto out;
436                 }
437
438                 bytes -= copy;
439                 if (!bytes)
440                         break;
441
442                 md->sg_copybreak = 0;
443                 if (++i == MAX_SKB_FRAGS)
444                         i = 0;
445         } while (i != md->sg_end);
446 out:
447         md->sg_curr = i;
448         return rc;
449 }
450
451 static int bpf_tcp_push(struct sock *sk, int apply_bytes,
452                         struct sk_msg_buff *md,
453                         int flags, bool uncharge)
454 {
455         bool apply = apply_bytes;
456         struct scatterlist *sg;
457         int offset, ret = 0;
458         struct page *p;
459         size_t size;
460
461         while (1) {
462                 sg = md->sg_data + md->sg_start;
463                 size = (apply && apply_bytes < sg->length) ?
464                         apply_bytes : sg->length;
465                 offset = sg->offset;
466
467                 tcp_rate_check_app_limited(sk);
468                 p = sg_page(sg);
469 retry:
470                 ret = do_tcp_sendpages(sk, p, offset, size, flags);
471                 if (ret != size) {
472                         if (ret > 0) {
473                                 if (apply)
474                                         apply_bytes -= ret;
475
476                                 sg->offset += ret;
477                                 sg->length -= ret;
478                                 size -= ret;
479                                 offset += ret;
480                                 if (uncharge)
481                                         sk_mem_uncharge(sk, ret);
482                                 goto retry;
483                         }
484
485                         return ret;
486                 }
487
488                 if (apply)
489                         apply_bytes -= ret;
490                 sg->offset += ret;
491                 sg->length -= ret;
492                 if (uncharge)
493                         sk_mem_uncharge(sk, ret);
494
495                 if (!sg->length) {
496                         put_page(p);
497                         md->sg_start++;
498                         if (md->sg_start == MAX_SKB_FRAGS)
499                                 md->sg_start = 0;
500                         sg_init_table(sg, 1);
501
502                         if (md->sg_start == md->sg_end)
503                                 break;
504                 }
505
506                 if (apply && !apply_bytes)
507                         break;
508         }
509         return 0;
510 }
511
512 static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
513 {
514         struct scatterlist *sg = md->sg_data + md->sg_start;
515
516         if (md->sg_copy[md->sg_start]) {
517                 md->data = md->data_end = 0;
518         } else {
519                 md->data = sg_virt(sg);
520                 md->data_end = md->data + sg->length;
521         }
522 }
523
524 static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
525 {
526         struct scatterlist *sg = md->sg_data;
527         int i = md->sg_start;
528
529         do {
530                 int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
531
532                 sk_mem_uncharge(sk, uncharge);
533                 bytes -= uncharge;
534                 if (!bytes)
535                         break;
536                 i++;
537                 if (i == MAX_SKB_FRAGS)
538                         i = 0;
539         } while (i != md->sg_end);
540 }
541
542 static void free_bytes_sg(struct sock *sk, int bytes,
543                           struct sk_msg_buff *md, bool charge)
544 {
545         struct scatterlist *sg = md->sg_data;
546         int i = md->sg_start, free;
547
548         while (bytes && sg[i].length) {
549                 free = sg[i].length;
550                 if (bytes < free) {
551                         sg[i].length -= bytes;
552                         sg[i].offset += bytes;
553                         if (charge)
554                                 sk_mem_uncharge(sk, bytes);
555                         break;
556                 }
557
558                 if (charge)
559                         sk_mem_uncharge(sk, sg[i].length);
560                 put_page(sg_page(&sg[i]));
561                 bytes -= sg[i].length;
562                 sg[i].length = 0;
563                 sg[i].page_link = 0;
564                 sg[i].offset = 0;
565                 i++;
566
567                 if (i == MAX_SKB_FRAGS)
568                         i = 0;
569         }
570         md->sg_start = i;
571 }
572
573 static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
574 {
575         struct scatterlist *sg = md->sg_data;
576         int i = start, free = 0;
577
578         while (sg[i].length) {
579                 free += sg[i].length;
580                 sk_mem_uncharge(sk, sg[i].length);
581                 if (!md->skb)
582                         put_page(sg_page(&sg[i]));
583                 sg[i].length = 0;
584                 sg[i].page_link = 0;
585                 sg[i].offset = 0;
586                 i++;
587
588                 if (i == MAX_SKB_FRAGS)
589                         i = 0;
590         }
591         if (md->skb)
592                 consume_skb(md->skb);
593
594         return free;
595 }
596
597 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
598 {
599         int free = free_sg(sk, md->sg_start, md);
600
601         md->sg_start = md->sg_end;
602         return free;
603 }
604
605 static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
606 {
607         return free_sg(sk, md->sg_curr, md);
608 }
609
610 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
611 {
612         return ((_rc == SK_PASS) ?
613                (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
614                __SK_DROP);
615 }
616
617 static unsigned int smap_do_tx_msg(struct sock *sk,
618                                    struct smap_psock *psock,
619                                    struct sk_msg_buff *md)
620 {
621         struct bpf_prog *prog;
622         unsigned int rc, _rc;
623
624         preempt_disable();
625         rcu_read_lock();
626
627         /* If the policy was removed mid-send then default to 'accept' */
628         prog = READ_ONCE(psock->bpf_tx_msg);
629         if (unlikely(!prog)) {
630                 _rc = SK_PASS;
631                 goto verdict;
632         }
633
634         bpf_compute_data_pointers_sg(md);
635         md->sk = sk;
636         rc = (*prog->bpf_func)(md, prog->insnsi);
637         psock->apply_bytes = md->apply_bytes;
638
639         /* Moving return codes from UAPI namespace into internal namespace */
640         _rc = bpf_map_msg_verdict(rc, md);
641
642         /* The psock has a refcount on the sock but not on the map and because
643          * we need to drop rcu read lock here its possible the map could be
644          * removed between here and when we need it to execute the sock
645          * redirect. So do the map lookup now for future use.
646          */
647         if (_rc == __SK_REDIRECT) {
648                 if (psock->sk_redir)
649                         sock_put(psock->sk_redir);
650                 psock->sk_redir = do_msg_redirect_map(md);
651                 if (!psock->sk_redir) {
652                         _rc = __SK_DROP;
653                         goto verdict;
654                 }
655                 sock_hold(psock->sk_redir);
656         }
657 verdict:
658         rcu_read_unlock();
659         preempt_enable();
660
661         return _rc;
662 }
663
664 static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
665                            struct smap_psock *psock,
666                            struct sk_msg_buff *md, int flags)
667 {
668         bool apply = apply_bytes;
669         size_t size, copied = 0;
670         struct sk_msg_buff *r;
671         int err = 0, i;
672
673         r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
674         if (unlikely(!r))
675                 return -ENOMEM;
676
677         lock_sock(sk);
678         r->sg_start = md->sg_start;
679         i = md->sg_start;
680
681         do {
682                 size = (apply && apply_bytes < md->sg_data[i].length) ?
683                         apply_bytes : md->sg_data[i].length;
684
685                 if (!sk_wmem_schedule(sk, size)) {
686                         if (!copied)
687                                 err = -ENOMEM;
688                         break;
689                 }
690
691                 sk_mem_charge(sk, size);
692                 r->sg_data[i] = md->sg_data[i];
693                 r->sg_data[i].length = size;
694                 md->sg_data[i].length -= size;
695                 md->sg_data[i].offset += size;
696                 copied += size;
697
698                 if (md->sg_data[i].length) {
699                         get_page(sg_page(&r->sg_data[i]));
700                         r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
701                 } else {
702                         i++;
703                         if (i == MAX_SKB_FRAGS)
704                                 i = 0;
705                         r->sg_end = i;
706                 }
707
708                 if (apply) {
709                         apply_bytes -= size;
710                         if (!apply_bytes)
711                                 break;
712                 }
713         } while (i != md->sg_end);
714
715         md->sg_start = i;
716
717         if (!err) {
718                 list_add_tail(&r->list, &psock->ingress);
719                 sk->sk_data_ready(sk);
720         } else {
721                 free_start_sg(sk, r);
722                 kfree(r);
723         }
724
725         release_sock(sk);
726         return err;
727 }
728
729 static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
730                                        struct sk_msg_buff *md,
731                                        int flags)
732 {
733         bool ingress = !!(md->flags & BPF_F_INGRESS);
734         struct smap_psock *psock;
735         int err = 0;
736
737         rcu_read_lock();
738         psock = smap_psock_sk(sk);
739         if (unlikely(!psock))
740                 goto out_rcu;
741
742         if (!refcount_inc_not_zero(&psock->refcnt))
743                 goto out_rcu;
744
745         rcu_read_unlock();
746
747         if (ingress) {
748                 err = bpf_tcp_ingress(sk, send, psock, md, flags);
749         } else {
750                 lock_sock(sk);
751                 err = bpf_tcp_push(sk, send, md, flags, false);
752                 release_sock(sk);
753         }
754         smap_release_sock(psock, sk);
755         if (unlikely(err))
756                 goto out;
757         return 0;
758 out_rcu:
759         rcu_read_unlock();
760 out:
761         free_bytes_sg(NULL, send, md, false);
762         return err;
763 }
764
765 static inline void bpf_md_init(struct smap_psock *psock)
766 {
767         if (!psock->apply_bytes) {
768                 psock->eval =  __SK_NONE;
769                 if (psock->sk_redir) {
770                         sock_put(psock->sk_redir);
771                         psock->sk_redir = NULL;
772                 }
773         }
774 }
775
776 static void apply_bytes_dec(struct smap_psock *psock, int i)
777 {
778         if (psock->apply_bytes) {
779                 if (psock->apply_bytes < i)
780                         psock->apply_bytes = 0;
781                 else
782                         psock->apply_bytes -= i;
783         }
784 }
785
786 static int bpf_exec_tx_verdict(struct smap_psock *psock,
787                                struct sk_msg_buff *m,
788                                struct sock *sk,
789                                int *copied, int flags)
790 {
791         bool cork = false, enospc = (m->sg_start == m->sg_end);
792         struct sock *redir;
793         int err = 0;
794         int send;
795
796 more_data:
797         if (psock->eval == __SK_NONE)
798                 psock->eval = smap_do_tx_msg(sk, psock, m);
799
800         if (m->cork_bytes &&
801             m->cork_bytes > psock->sg_size && !enospc) {
802                 psock->cork_bytes = m->cork_bytes - psock->sg_size;
803                 if (!psock->cork) {
804                         psock->cork = kcalloc(1,
805                                         sizeof(struct sk_msg_buff),
806                                         GFP_ATOMIC | __GFP_NOWARN);
807
808                         if (!psock->cork) {
809                                 err = -ENOMEM;
810                                 goto out_err;
811                         }
812                 }
813                 memcpy(psock->cork, m, sizeof(*m));
814                 goto out_err;
815         }
816
817         send = psock->sg_size;
818         if (psock->apply_bytes && psock->apply_bytes < send)
819                 send = psock->apply_bytes;
820
821         switch (psock->eval) {
822         case __SK_PASS:
823                 err = bpf_tcp_push(sk, send, m, flags, true);
824                 if (unlikely(err)) {
825                         *copied -= free_start_sg(sk, m);
826                         break;
827                 }
828
829                 apply_bytes_dec(psock, send);
830                 psock->sg_size -= send;
831                 break;
832         case __SK_REDIRECT:
833                 redir = psock->sk_redir;
834                 apply_bytes_dec(psock, send);
835
836                 if (psock->cork) {
837                         cork = true;
838                         psock->cork = NULL;
839                 }
840
841                 return_mem_sg(sk, send, m);
842                 release_sock(sk);
843
844                 err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
845                 lock_sock(sk);
846
847                 if (unlikely(err < 0)) {
848                         free_start_sg(sk, m);
849                         psock->sg_size = 0;
850                         if (!cork)
851                                 *copied -= send;
852                 } else {
853                         psock->sg_size -= send;
854                 }
855
856                 if (cork) {
857                         free_start_sg(sk, m);
858                         psock->sg_size = 0;
859                         kfree(m);
860                         m = NULL;
861                         err = 0;
862                 }
863                 break;
864         case __SK_DROP:
865         default:
866                 free_bytes_sg(sk, send, m, true);
867                 apply_bytes_dec(psock, send);
868                 *copied -= send;
869                 psock->sg_size -= send;
870                 err = -EACCES;
871                 break;
872         }
873
874         if (likely(!err)) {
875                 bpf_md_init(psock);
876                 if (m &&
877                     m->sg_data[m->sg_start].page_link &&
878                     m->sg_data[m->sg_start].length)
879                         goto more_data;
880         }
881
882 out_err:
883         return err;
884 }
885
886 static int bpf_wait_data(struct sock *sk,
887                          struct smap_psock *psk, int flags,
888                          long timeo, int *err)
889 {
890         int rc;
891
892         DEFINE_WAIT_FUNC(wait, woken_wake_function);
893
894         add_wait_queue(sk_sleep(sk), &wait);
895         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
896         rc = sk_wait_event(sk, &timeo,
897                            !list_empty(&psk->ingress) ||
898                            !skb_queue_empty(&sk->sk_receive_queue),
899                            &wait);
900         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
901         remove_wait_queue(sk_sleep(sk), &wait);
902
903         return rc;
904 }
905
906 static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
907                            int nonblock, int flags, int *addr_len)
908 {
909         struct iov_iter *iter = &msg->msg_iter;
910         struct smap_psock *psock;
911         int copied = 0;
912
913         if (unlikely(flags & MSG_ERRQUEUE))
914                 return inet_recv_error(sk, msg, len, addr_len);
915
916         rcu_read_lock();
917         psock = smap_psock_sk(sk);
918         if (unlikely(!psock))
919                 goto out;
920
921         if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
922                 goto out;
923         rcu_read_unlock();
924
925         if (!skb_queue_empty(&sk->sk_receive_queue))
926                 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
927
928         lock_sock(sk);
929 bytes_ready:
930         while (copied != len) {
931                 struct scatterlist *sg;
932                 struct sk_msg_buff *md;
933                 int i;
934
935                 md = list_first_entry_or_null(&psock->ingress,
936                                               struct sk_msg_buff, list);
937                 if (unlikely(!md))
938                         break;
939                 i = md->sg_start;
940                 do {
941                         struct page *page;
942                         int n, copy;
943
944                         sg = &md->sg_data[i];
945                         copy = sg->length;
946                         page = sg_page(sg);
947
948                         if (copied + copy > len)
949                                 copy = len - copied;
950
951                         n = copy_page_to_iter(page, sg->offset, copy, iter);
952                         if (n != copy) {
953                                 md->sg_start = i;
954                                 release_sock(sk);
955                                 smap_release_sock(psock, sk);
956                                 return -EFAULT;
957                         }
958
959                         copied += copy;
960                         sg->offset += copy;
961                         sg->length -= copy;
962                         sk_mem_uncharge(sk, copy);
963
964                         if (!sg->length) {
965                                 i++;
966                                 if (i == MAX_SKB_FRAGS)
967                                         i = 0;
968                                 if (!md->skb)
969                                         put_page(page);
970                         }
971                         if (copied == len)
972                                 break;
973                 } while (i != md->sg_end);
974                 md->sg_start = i;
975
976                 if (!sg->length && md->sg_start == md->sg_end) {
977                         list_del(&md->list);
978                         if (md->skb)
979                                 consume_skb(md->skb);
980                         kfree(md);
981                 }
982         }
983
984         if (!copied) {
985                 long timeo;
986                 int data;
987                 int err = 0;
988
989                 timeo = sock_rcvtimeo(sk, nonblock);
990                 data = bpf_wait_data(sk, psock, flags, timeo, &err);
991
992                 if (data) {
993                         if (!skb_queue_empty(&sk->sk_receive_queue)) {
994                                 release_sock(sk);
995                                 smap_release_sock(psock, sk);
996                                 copied = tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
997                                 return copied;
998                         }
999                         goto bytes_ready;
1000                 }
1001
1002                 if (err)
1003                         copied = err;
1004         }
1005
1006         release_sock(sk);
1007         smap_release_sock(psock, sk);
1008         return copied;
1009 out:
1010         rcu_read_unlock();
1011         return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1012 }
1013
1014
1015 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
1016 {
1017         int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
1018         struct sk_msg_buff md = {0};
1019         unsigned int sg_copy = 0;
1020         struct smap_psock *psock;
1021         int copied = 0, err = 0;
1022         struct scatterlist *sg;
1023         long timeo;
1024
1025         /* Its possible a sock event or user removed the psock _but_ the ops
1026          * have not been reprogrammed yet so we get here. In this case fallback
1027          * to tcp_sendmsg. Note this only works because we _only_ ever allow
1028          * a single ULP there is no hierarchy here.
1029          */
1030         rcu_read_lock();
1031         psock = smap_psock_sk(sk);
1032         if (unlikely(!psock)) {
1033                 rcu_read_unlock();
1034                 return tcp_sendmsg(sk, msg, size);
1035         }
1036
1037         /* Increment the psock refcnt to ensure its not released while sending a
1038          * message. Required because sk lookup and bpf programs are used in
1039          * separate rcu critical sections. Its OK if we lose the map entry
1040          * but we can't lose the sock reference.
1041          */
1042         if (!refcount_inc_not_zero(&psock->refcnt)) {
1043                 rcu_read_unlock();
1044                 return tcp_sendmsg(sk, msg, size);
1045         }
1046
1047         sg = md.sg_data;
1048         sg_init_marker(sg, MAX_SKB_FRAGS);
1049         rcu_read_unlock();
1050
1051         lock_sock(sk);
1052         timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1053
1054         while (msg_data_left(msg)) {
1055                 struct sk_msg_buff *m = NULL;
1056                 bool enospc = false;
1057                 int copy;
1058
1059                 if (sk->sk_err) {
1060                         err = -sk->sk_err;
1061                         goto out_err;
1062                 }
1063
1064                 copy = msg_data_left(msg);
1065                 if (!sk_stream_memory_free(sk))
1066                         goto wait_for_sndbuf;
1067
1068                 m = psock->cork_bytes ? psock->cork : &md;
1069                 m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
1070                 err = sk_alloc_sg(sk, copy, m->sg_data,
1071                                   m->sg_start, &m->sg_end, &sg_copy,
1072                                   m->sg_end - 1);
1073                 if (err) {
1074                         if (err != -ENOSPC)
1075                                 goto wait_for_memory;
1076                         enospc = true;
1077                         copy = sg_copy;
1078                 }
1079
1080                 err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
1081                 if (err < 0) {
1082                         free_curr_sg(sk, m);
1083                         goto out_err;
1084                 }
1085
1086                 psock->sg_size += copy;
1087                 copied += copy;
1088                 sg_copy = 0;
1089
1090                 /* When bytes are being corked skip running BPF program and
1091                  * applying verdict unless there is no more buffer space. In
1092                  * the ENOSPC case simply run BPF prorgram with currently
1093                  * accumulated data. We don't have much choice at this point
1094                  * we could try extending the page frags or chaining complex
1095                  * frags but even in these cases _eventually_ we will hit an
1096                  * OOM scenario. More complex recovery schemes may be
1097                  * implemented in the future, but BPF programs must handle
1098                  * the case where apply_cork requests are not honored. The
1099                  * canonical method to verify this is to check data length.
1100                  */
1101                 if (psock->cork_bytes) {
1102                         if (copy > psock->cork_bytes)
1103                                 psock->cork_bytes = 0;
1104                         else
1105                                 psock->cork_bytes -= copy;
1106
1107                         if (psock->cork_bytes && !enospc)
1108                                 goto out_cork;
1109
1110                         /* All cork bytes accounted for re-run filter */
1111                         psock->eval = __SK_NONE;
1112                         psock->cork_bytes = 0;
1113                 }
1114
1115                 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1116                 if (unlikely(err < 0))
1117                         goto out_err;
1118                 continue;
1119 wait_for_sndbuf:
1120                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1121 wait_for_memory:
1122                 err = sk_stream_wait_memory(sk, &timeo);
1123                 if (err) {
1124                         if (m && m != psock->cork)
1125                                 free_start_sg(sk, m);
1126                         goto out_err;
1127                 }
1128         }
1129 out_err:
1130         if (err < 0)
1131                 err = sk_stream_error(sk, msg->msg_flags, err);
1132 out_cork:
1133         release_sock(sk);
1134         smap_release_sock(psock, sk);
1135         return copied ? copied : err;
1136 }
1137
1138 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
1139                             int offset, size_t size, int flags)
1140 {
1141         struct sk_msg_buff md = {0}, *m = NULL;
1142         int err = 0, copied = 0;
1143         struct smap_psock *psock;
1144         struct scatterlist *sg;
1145         bool enospc = false;
1146
1147         rcu_read_lock();
1148         psock = smap_psock_sk(sk);
1149         if (unlikely(!psock))
1150                 goto accept;
1151
1152         if (!refcount_inc_not_zero(&psock->refcnt))
1153                 goto accept;
1154         rcu_read_unlock();
1155
1156         lock_sock(sk);
1157
1158         if (psock->cork_bytes) {
1159                 m = psock->cork;
1160                 sg = &m->sg_data[m->sg_end];
1161         } else {
1162                 m = &md;
1163                 sg = m->sg_data;
1164                 sg_init_marker(sg, MAX_SKB_FRAGS);
1165         }
1166
1167         /* Catch case where ring is full and sendpage is stalled. */
1168         if (unlikely(m->sg_end == m->sg_start &&
1169             m->sg_data[m->sg_end].length))
1170                 goto out_err;
1171
1172         psock->sg_size += size;
1173         sg_set_page(sg, page, size, offset);
1174         get_page(page);
1175         m->sg_copy[m->sg_end] = true;
1176         sk_mem_charge(sk, size);
1177         m->sg_end++;
1178         copied = size;
1179
1180         if (m->sg_end == MAX_SKB_FRAGS)
1181                 m->sg_end = 0;
1182
1183         if (m->sg_end == m->sg_start)
1184                 enospc = true;
1185
1186         if (psock->cork_bytes) {
1187                 if (size > psock->cork_bytes)
1188                         psock->cork_bytes = 0;
1189                 else
1190                         psock->cork_bytes -= size;
1191
1192                 if (psock->cork_bytes && !enospc)
1193                         goto out_err;
1194
1195                 /* All cork bytes accounted for re-run filter */
1196                 psock->eval = __SK_NONE;
1197                 psock->cork_bytes = 0;
1198         }
1199
1200         err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1201 out_err:
1202         release_sock(sk);
1203         smap_release_sock(psock, sk);
1204         return copied ? copied : err;
1205 accept:
1206         rcu_read_unlock();
1207         return tcp_sendpage(sk, page, offset, size, flags);
1208 }
1209
1210 static void bpf_tcp_msg_add(struct smap_psock *psock,
1211                             struct sock *sk,
1212                             struct bpf_prog *tx_msg)
1213 {
1214         struct bpf_prog *orig_tx_msg;
1215
1216         orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
1217         if (orig_tx_msg)
1218                 bpf_prog_put(orig_tx_msg);
1219 }
1220
1221 static int bpf_tcp_ulp_register(void)
1222 {
1223         build_protos(bpf_tcp_prots[SOCKMAP_IPV4], &tcp_prot);
1224         /* Once BPF TX ULP is registered it is never unregistered. It
1225          * will be in the ULP list for the lifetime of the system. Doing
1226          * duplicate registers is not a problem.
1227          */
1228         return tcp_register_ulp(&bpf_tcp_ulp_ops);
1229 }
1230
1231 static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
1232 {
1233         struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
1234         int rc;
1235
1236         if (unlikely(!prog))
1237                 return __SK_DROP;
1238
1239         skb_orphan(skb);
1240         /* We need to ensure that BPF metadata for maps is also cleared
1241          * when we orphan the skb so that we don't have the possibility
1242          * to reference a stale map.
1243          */
1244         TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
1245         skb->sk = psock->sock;
1246         bpf_compute_data_end_sk_skb(skb);
1247         preempt_disable();
1248         rc = (*prog->bpf_func)(skb, prog->insnsi);
1249         preempt_enable();
1250         skb->sk = NULL;
1251
1252         /* Moving return codes from UAPI namespace into internal namespace */
1253         return rc == SK_PASS ?
1254                 (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
1255                 __SK_DROP;
1256 }
1257
1258 static int smap_do_ingress(struct smap_psock *psock, struct sk_buff *skb)
1259 {
1260         struct sock *sk = psock->sock;
1261         int copied = 0, num_sg;
1262         struct sk_msg_buff *r;
1263
1264         r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_ATOMIC);
1265         if (unlikely(!r))
1266                 return -EAGAIN;
1267
1268         if (!sk_rmem_schedule(sk, skb, skb->len)) {
1269                 kfree(r);
1270                 return -EAGAIN;
1271         }
1272
1273         sg_init_table(r->sg_data, MAX_SKB_FRAGS);
1274         num_sg = skb_to_sgvec(skb, r->sg_data, 0, skb->len);
1275         if (unlikely(num_sg < 0)) {
1276                 kfree(r);
1277                 return num_sg;
1278         }
1279         sk_mem_charge(sk, skb->len);
1280         copied = skb->len;
1281         r->sg_start = 0;
1282         r->sg_end = num_sg == MAX_SKB_FRAGS ? 0 : num_sg;
1283         r->skb = skb;
1284         list_add_tail(&r->list, &psock->ingress);
1285         sk->sk_data_ready(sk);
1286         return copied;
1287 }
1288
1289 static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
1290 {
1291         struct smap_psock *peer;
1292         struct sock *sk;
1293         __u32 in;
1294         int rc;
1295
1296         rc = smap_verdict_func(psock, skb);
1297         switch (rc) {
1298         case __SK_REDIRECT:
1299                 sk = do_sk_redirect_map(skb);
1300                 if (!sk) {
1301                         kfree_skb(skb);
1302                         break;
1303                 }
1304
1305                 peer = smap_psock_sk(sk);
1306                 in = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1307
1308                 if (unlikely(!peer || sock_flag(sk, SOCK_DEAD) ||
1309                              !test_bit(SMAP_TX_RUNNING, &peer->state))) {
1310                         kfree_skb(skb);
1311                         break;
1312                 }
1313
1314                 if (!in && sock_writeable(sk)) {
1315                         skb_set_owner_w(skb, sk);
1316                         skb_queue_tail(&peer->rxqueue, skb);
1317                         schedule_work(&peer->tx_work);
1318                         break;
1319                 } else if (in &&
1320                            atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) {
1321                         skb_queue_tail(&peer->rxqueue, skb);
1322                         schedule_work(&peer->tx_work);
1323                         break;
1324                 }
1325         /* Fall through and free skb otherwise */
1326         case __SK_DROP:
1327         default:
1328                 kfree_skb(skb);
1329         }
1330 }
1331
1332 static void smap_report_sk_error(struct smap_psock *psock, int err)
1333 {
1334         struct sock *sk = psock->sock;
1335
1336         sk->sk_err = err;
1337         sk->sk_error_report(sk);
1338 }
1339
1340 static void smap_read_sock_strparser(struct strparser *strp,
1341                                      struct sk_buff *skb)
1342 {
1343         struct smap_psock *psock;
1344
1345         rcu_read_lock();
1346         psock = container_of(strp, struct smap_psock, strp);
1347         smap_do_verdict(psock, skb);
1348         rcu_read_unlock();
1349 }
1350
1351 /* Called with lock held on socket */
1352 static void smap_data_ready(struct sock *sk)
1353 {
1354         struct smap_psock *psock;
1355
1356         rcu_read_lock();
1357         psock = smap_psock_sk(sk);
1358         if (likely(psock)) {
1359                 write_lock_bh(&sk->sk_callback_lock);
1360                 strp_data_ready(&psock->strp);
1361                 write_unlock_bh(&sk->sk_callback_lock);
1362         }
1363         rcu_read_unlock();
1364 }
1365
1366 static void smap_tx_work(struct work_struct *w)
1367 {
1368         struct smap_psock *psock;
1369         struct sk_buff *skb;
1370         int rem, off, n;
1371
1372         psock = container_of(w, struct smap_psock, tx_work);
1373
1374         /* lock sock to avoid losing sk_socket at some point during loop */
1375         lock_sock(psock->sock);
1376         if (psock->save_skb) {
1377                 skb = psock->save_skb;
1378                 rem = psock->save_rem;
1379                 off = psock->save_off;
1380                 psock->save_skb = NULL;
1381                 goto start;
1382         }
1383
1384         while ((skb = skb_dequeue(&psock->rxqueue))) {
1385                 __u32 flags;
1386
1387                 rem = skb->len;
1388                 off = 0;
1389 start:
1390                 flags = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1391                 do {
1392                         if (likely(psock->sock->sk_socket)) {
1393                                 if (flags)
1394                                         n = smap_do_ingress(psock, skb);
1395                                 else
1396                                         n = skb_send_sock_locked(psock->sock,
1397                                                                  skb, off, rem);
1398                         } else {
1399                                 n = -EINVAL;
1400                         }
1401
1402                         if (n <= 0) {
1403                                 if (n == -EAGAIN) {
1404                                         /* Retry when space is available */
1405                                         psock->save_skb = skb;
1406                                         psock->save_rem = rem;
1407                                         psock->save_off = off;
1408                                         goto out;
1409                                 }
1410                                 /* Hard errors break pipe and stop xmit */
1411                                 smap_report_sk_error(psock, n ? -n : EPIPE);
1412                                 clear_bit(SMAP_TX_RUNNING, &psock->state);
1413                                 kfree_skb(skb);
1414                                 goto out;
1415                         }
1416                         rem -= n;
1417                         off += n;
1418                 } while (rem);
1419
1420                 if (!flags)
1421                         kfree_skb(skb);
1422         }
1423 out:
1424         release_sock(psock->sock);
1425 }
1426
1427 static void smap_write_space(struct sock *sk)
1428 {
1429         struct smap_psock *psock;
1430
1431         rcu_read_lock();
1432         psock = smap_psock_sk(sk);
1433         if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
1434                 schedule_work(&psock->tx_work);
1435         rcu_read_unlock();
1436 }
1437
1438 static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
1439 {
1440         if (!psock->strp_enabled)
1441                 return;
1442         sk->sk_data_ready = psock->save_data_ready;
1443         sk->sk_write_space = psock->save_write_space;
1444         psock->save_data_ready = NULL;
1445         psock->save_write_space = NULL;
1446         strp_stop(&psock->strp);
1447         psock->strp_enabled = false;
1448 }
1449
1450 static void smap_destroy_psock(struct rcu_head *rcu)
1451 {
1452         struct smap_psock *psock = container_of(rcu,
1453                                                   struct smap_psock, rcu);
1454
1455         /* Now that a grace period has passed there is no longer
1456          * any reference to this sock in the sockmap so we can
1457          * destroy the psock, strparser, and bpf programs. But,
1458          * because we use workqueue sync operations we can not
1459          * do it in rcu context
1460          */
1461         schedule_work(&psock->gc_work);
1462 }
1463
1464 static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1465 {
1466         if (refcount_dec_and_test(&psock->refcnt)) {
1467                 tcp_cleanup_ulp(sock);
1468                 write_lock_bh(&sock->sk_callback_lock);
1469                 smap_stop_sock(psock, sock);
1470                 write_unlock_bh(&sock->sk_callback_lock);
1471                 clear_bit(SMAP_TX_RUNNING, &psock->state);
1472                 rcu_assign_sk_user_data(sock, NULL);
1473                 call_rcu_sched(&psock->rcu, smap_destroy_psock);
1474         }
1475 }
1476
1477 static int smap_parse_func_strparser(struct strparser *strp,
1478                                        struct sk_buff *skb)
1479 {
1480         struct smap_psock *psock;
1481         struct bpf_prog *prog;
1482         int rc;
1483
1484         rcu_read_lock();
1485         psock = container_of(strp, struct smap_psock, strp);
1486         prog = READ_ONCE(psock->bpf_parse);
1487
1488         if (unlikely(!prog)) {
1489                 rcu_read_unlock();
1490                 return skb->len;
1491         }
1492
1493         /* Attach socket for bpf program to use if needed we can do this
1494          * because strparser clones the skb before handing it to a upper
1495          * layer, meaning skb_orphan has been called. We NULL sk on the
1496          * way out to ensure we don't trigger a BUG_ON in skb/sk operations
1497          * later and because we are not charging the memory of this skb to
1498          * any socket yet.
1499          */
1500         skb->sk = psock->sock;
1501         bpf_compute_data_end_sk_skb(skb);
1502         rc = (*prog->bpf_func)(skb, prog->insnsi);
1503         skb->sk = NULL;
1504         rcu_read_unlock();
1505         return rc;
1506 }
1507
1508 static int smap_read_sock_done(struct strparser *strp, int err)
1509 {
1510         return err;
1511 }
1512
1513 static int smap_init_sock(struct smap_psock *psock,
1514                           struct sock *sk)
1515 {
1516         static const struct strp_callbacks cb = {
1517                 .rcv_msg = smap_read_sock_strparser,
1518                 .parse_msg = smap_parse_func_strparser,
1519                 .read_sock_done = smap_read_sock_done,
1520         };
1521
1522         return strp_init(&psock->strp, sk, &cb);
1523 }
1524
1525 static void smap_init_progs(struct smap_psock *psock,
1526                             struct bpf_prog *verdict,
1527                             struct bpf_prog *parse)
1528 {
1529         struct bpf_prog *orig_parse, *orig_verdict;
1530
1531         orig_parse = xchg(&psock->bpf_parse, parse);
1532         orig_verdict = xchg(&psock->bpf_verdict, verdict);
1533
1534         if (orig_verdict)
1535                 bpf_prog_put(orig_verdict);
1536         if (orig_parse)
1537                 bpf_prog_put(orig_parse);
1538 }
1539
1540 static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1541 {
1542         if (sk->sk_data_ready == smap_data_ready)
1543                 return;
1544         psock->save_data_ready = sk->sk_data_ready;
1545         psock->save_write_space = sk->sk_write_space;
1546         sk->sk_data_ready = smap_data_ready;
1547         sk->sk_write_space = smap_write_space;
1548         psock->strp_enabled = true;
1549 }
1550
1551 static void sock_map_remove_complete(struct bpf_stab *stab)
1552 {
1553         bpf_map_area_free(stab->sock_map);
1554         kfree(stab);
1555 }
1556
1557 static void smap_gc_work(struct work_struct *w)
1558 {
1559         struct smap_psock_map_entry *e, *tmp;
1560         struct sk_msg_buff *md, *mtmp;
1561         struct smap_psock *psock;
1562
1563         psock = container_of(w, struct smap_psock, gc_work);
1564
1565         /* no callback lock needed because we already detached sockmap ops */
1566         if (psock->strp_enabled)
1567                 strp_done(&psock->strp);
1568
1569         cancel_work_sync(&psock->tx_work);
1570         __skb_queue_purge(&psock->rxqueue);
1571
1572         /* At this point all strparser and xmit work must be complete */
1573         if (psock->bpf_parse)
1574                 bpf_prog_put(psock->bpf_parse);
1575         if (psock->bpf_verdict)
1576                 bpf_prog_put(psock->bpf_verdict);
1577         if (psock->bpf_tx_msg)
1578                 bpf_prog_put(psock->bpf_tx_msg);
1579
1580         if (psock->cork) {
1581                 free_start_sg(psock->sock, psock->cork);
1582                 kfree(psock->cork);
1583         }
1584
1585         list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
1586                 list_del(&md->list);
1587                 free_start_sg(psock->sock, md);
1588                 kfree(md);
1589         }
1590
1591         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1592                 list_del(&e->list);
1593                 kfree(e);
1594         }
1595
1596         if (psock->sk_redir)
1597                 sock_put(psock->sk_redir);
1598
1599         sock_put(psock->sock);
1600         kfree(psock);
1601 }
1602
1603 static struct smap_psock *smap_init_psock(struct sock *sock, int node)
1604 {
1605         struct smap_psock *psock;
1606
1607         psock = kzalloc_node(sizeof(struct smap_psock),
1608                              GFP_ATOMIC | __GFP_NOWARN,
1609                              node);
1610         if (!psock)
1611                 return ERR_PTR(-ENOMEM);
1612
1613         psock->eval =  __SK_NONE;
1614         psock->sock = sock;
1615         skb_queue_head_init(&psock->rxqueue);
1616         INIT_WORK(&psock->tx_work, smap_tx_work);
1617         INIT_WORK(&psock->gc_work, smap_gc_work);
1618         INIT_LIST_HEAD(&psock->maps);
1619         INIT_LIST_HEAD(&psock->ingress);
1620         refcount_set(&psock->refcnt, 1);
1621         spin_lock_init(&psock->maps_lock);
1622
1623         rcu_assign_sk_user_data(sock, psock);
1624         sock_hold(sock);
1625         return psock;
1626 }
1627
1628 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1629 {
1630         struct bpf_stab *stab;
1631         u64 cost;
1632         int err;
1633
1634         if (!capable(CAP_NET_ADMIN))
1635                 return ERR_PTR(-EPERM);
1636
1637         /* check sanity of attributes */
1638         if (attr->max_entries == 0 || attr->key_size != 4 ||
1639             attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1640                 return ERR_PTR(-EINVAL);
1641
1642         err = bpf_tcp_ulp_register();
1643         if (err && err != -EEXIST)
1644                 return ERR_PTR(err);
1645
1646         stab = kzalloc(sizeof(*stab), GFP_USER);
1647         if (!stab)
1648                 return ERR_PTR(-ENOMEM);
1649
1650         bpf_map_init_from_attr(&stab->map, attr);
1651         raw_spin_lock_init(&stab->lock);
1652
1653         /* make sure page count doesn't overflow */
1654         cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1655         err = -EINVAL;
1656         if (cost >= U32_MAX - PAGE_SIZE)
1657                 goto free_stab;
1658
1659         stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
1660
1661         /* if map size is larger than memlock limit, reject it early */
1662         err = bpf_map_precharge_memlock(stab->map.pages);
1663         if (err)
1664                 goto free_stab;
1665
1666         err = -ENOMEM;
1667         stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1668                                             sizeof(struct sock *),
1669                                             stab->map.numa_node);
1670         if (!stab->sock_map)
1671                 goto free_stab;
1672
1673         return &stab->map;
1674 free_stab:
1675         kfree(stab);
1676         return ERR_PTR(err);
1677 }
1678
1679 static void smap_list_map_remove(struct smap_psock *psock,
1680                                  struct sock **entry)
1681 {
1682         struct smap_psock_map_entry *e, *tmp;
1683
1684         spin_lock_bh(&psock->maps_lock);
1685         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1686                 if (e->entry == entry) {
1687                         list_del(&e->list);
1688                         kfree(e);
1689                 }
1690         }
1691         spin_unlock_bh(&psock->maps_lock);
1692 }
1693
1694 static void smap_list_hash_remove(struct smap_psock *psock,
1695                                   struct htab_elem *hash_link)
1696 {
1697         struct smap_psock_map_entry *e, *tmp;
1698
1699         spin_lock_bh(&psock->maps_lock);
1700         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1701                 struct htab_elem *c = rcu_dereference(e->hash_link);
1702
1703                 if (c == hash_link) {
1704                         list_del(&e->list);
1705                         kfree(e);
1706                 }
1707         }
1708         spin_unlock_bh(&psock->maps_lock);
1709 }
1710
1711 static void sock_map_free(struct bpf_map *map)
1712 {
1713         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1714         int i;
1715
1716         synchronize_rcu();
1717
1718         /* At this point no update, lookup or delete operations can happen.
1719          * However, be aware we can still get a socket state event updates,
1720          * and data ready callabacks that reference the psock from sk_user_data
1721          * Also psock worker threads are still in-flight. So smap_release_sock
1722          * will only free the psock after cancel_sync on the worker threads
1723          * and a grace period expire to ensure psock is really safe to remove.
1724          */
1725         rcu_read_lock();
1726         raw_spin_lock_bh(&stab->lock);
1727         for (i = 0; i < stab->map.max_entries; i++) {
1728                 struct smap_psock *psock;
1729                 struct sock *sock;
1730
1731                 sock = stab->sock_map[i];
1732                 if (!sock)
1733                         continue;
1734                 stab->sock_map[i] = NULL;
1735                 psock = smap_psock_sk(sock);
1736                 /* This check handles a racing sock event that can get the
1737                  * sk_callback_lock before this case but after xchg happens
1738                  * causing the refcnt to hit zero and sock user data (psock)
1739                  * to be null and queued for garbage collection.
1740                  */
1741                 if (likely(psock)) {
1742                         smap_list_map_remove(psock, &stab->sock_map[i]);
1743                         smap_release_sock(psock, sock);
1744                 }
1745         }
1746         raw_spin_unlock_bh(&stab->lock);
1747         rcu_read_unlock();
1748
1749         sock_map_remove_complete(stab);
1750 }
1751
1752 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1753 {
1754         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1755         u32 i = key ? *(u32 *)key : U32_MAX;
1756         u32 *next = (u32 *)next_key;
1757
1758         if (i >= stab->map.max_entries) {
1759                 *next = 0;
1760                 return 0;
1761         }
1762
1763         if (i == stab->map.max_entries - 1)
1764                 return -ENOENT;
1765
1766         *next = i + 1;
1767         return 0;
1768 }
1769
1770 struct sock  *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1771 {
1772         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1773
1774         if (key >= map->max_entries)
1775                 return NULL;
1776
1777         return READ_ONCE(stab->sock_map[key]);
1778 }
1779
1780 static int sock_map_delete_elem(struct bpf_map *map, void *key)
1781 {
1782         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1783         struct smap_psock *psock;
1784         int k = *(u32 *)key;
1785         struct sock *sock;
1786
1787         if (k >= map->max_entries)
1788                 return -EINVAL;
1789
1790         raw_spin_lock_bh(&stab->lock);
1791         sock = stab->sock_map[k];
1792         stab->sock_map[k] = NULL;
1793         raw_spin_unlock_bh(&stab->lock);
1794         if (!sock)
1795                 return -EINVAL;
1796
1797         psock = smap_psock_sk(sock);
1798         if (!psock)
1799                 return 0;
1800         if (psock->bpf_parse) {
1801                 write_lock_bh(&sock->sk_callback_lock);
1802                 smap_stop_sock(psock, sock);
1803                 write_unlock_bh(&sock->sk_callback_lock);
1804         }
1805         smap_list_map_remove(psock, &stab->sock_map[k]);
1806         smap_release_sock(psock, sock);
1807         return 0;
1808 }
1809
1810 /* Locking notes: Concurrent updates, deletes, and lookups are allowed and are
1811  * done inside rcu critical sections. This ensures on updates that the psock
1812  * will not be released via smap_release_sock() until concurrent updates/deletes
1813  * complete. All operations operate on sock_map using cmpxchg and xchg
1814  * operations to ensure we do not get stale references. Any reads into the
1815  * map must be done with READ_ONCE() because of this.
1816  *
1817  * A psock is destroyed via call_rcu and after any worker threads are cancelled
1818  * and syncd so we are certain all references from the update/lookup/delete
1819  * operations as well as references in the data path are no longer in use.
1820  *
1821  * Psocks may exist in multiple maps, but only a single set of parse/verdict
1822  * programs may be inherited from the maps it belongs to. A reference count
1823  * is kept with the total number of references to the psock from all maps. The
1824  * psock will not be released until this reaches zero. The psock and sock
1825  * user data data use the sk_callback_lock to protect critical data structures
1826  * from concurrent access. This allows us to avoid two updates from modifying
1827  * the user data in sock and the lock is required anyways for modifying
1828  * callbacks, we simply increase its scope slightly.
1829  *
1830  * Rules to follow,
1831  *  - psock must always be read inside RCU critical section
1832  *  - sk_user_data must only be modified inside sk_callback_lock and read
1833  *    inside RCU critical section.
1834  *  - psock->maps list must only be read & modified inside sk_callback_lock
1835  *  - sock_map must use READ_ONCE and (cmp)xchg operations
1836  *  - BPF verdict/parse programs must use READ_ONCE and xchg operations
1837  */
1838
1839 static int __sock_map_ctx_update_elem(struct bpf_map *map,
1840                                       struct bpf_sock_progs *progs,
1841                                       struct sock *sock,
1842                                       void *key)
1843 {
1844         struct bpf_prog *verdict, *parse, *tx_msg;
1845         struct smap_psock *psock;
1846         bool new = false;
1847         int err = 0;
1848
1849         /* 1. If sock map has BPF programs those will be inherited by the
1850          * sock being added. If the sock is already attached to BPF programs
1851          * this results in an error.
1852          */
1853         verdict = READ_ONCE(progs->bpf_verdict);
1854         parse = READ_ONCE(progs->bpf_parse);
1855         tx_msg = READ_ONCE(progs->bpf_tx_msg);
1856
1857         if (parse && verdict) {
1858                 /* bpf prog refcnt may be zero if a concurrent attach operation
1859                  * removes the program after the above READ_ONCE() but before
1860                  * we increment the refcnt. If this is the case abort with an
1861                  * error.
1862                  */
1863                 verdict = bpf_prog_inc_not_zero(verdict);
1864                 if (IS_ERR(verdict))
1865                         return PTR_ERR(verdict);
1866
1867                 parse = bpf_prog_inc_not_zero(parse);
1868                 if (IS_ERR(parse)) {
1869                         bpf_prog_put(verdict);
1870                         return PTR_ERR(parse);
1871                 }
1872         }
1873
1874         if (tx_msg) {
1875                 tx_msg = bpf_prog_inc_not_zero(tx_msg);
1876                 if (IS_ERR(tx_msg)) {
1877                         if (parse && verdict) {
1878                                 bpf_prog_put(parse);
1879                                 bpf_prog_put(verdict);
1880                         }
1881                         return PTR_ERR(tx_msg);
1882                 }
1883         }
1884
1885         psock = smap_psock_sk(sock);
1886
1887         /* 2. Do not allow inheriting programs if psock exists and has
1888          * already inherited programs. This would create confusion on
1889          * which parser/verdict program is running. If no psock exists
1890          * create one. Inside sk_callback_lock to ensure concurrent create
1891          * doesn't update user data.
1892          */
1893         if (psock) {
1894                 if (READ_ONCE(psock->bpf_parse) && parse) {
1895                         err = -EBUSY;
1896                         goto out_progs;
1897                 }
1898                 if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1899                         err = -EBUSY;
1900                         goto out_progs;
1901                 }
1902                 if (!refcount_inc_not_zero(&psock->refcnt)) {
1903                         err = -EAGAIN;
1904                         goto out_progs;
1905                 }
1906         } else {
1907                 psock = smap_init_psock(sock, map->numa_node);
1908                 if (IS_ERR(psock)) {
1909                         err = PTR_ERR(psock);
1910                         goto out_progs;
1911                 }
1912
1913                 set_bit(SMAP_TX_RUNNING, &psock->state);
1914                 new = true;
1915         }
1916
1917         /* 3. At this point we have a reference to a valid psock that is
1918          * running. Attach any BPF programs needed.
1919          */
1920         if (tx_msg)
1921                 bpf_tcp_msg_add(psock, sock, tx_msg);
1922         if (new) {
1923                 err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
1924                 if (err)
1925                         goto out_free;
1926         }
1927
1928         if (parse && verdict && !psock->strp_enabled) {
1929                 err = smap_init_sock(psock, sock);
1930                 if (err)
1931                         goto out_free;
1932                 smap_init_progs(psock, verdict, parse);
1933                 write_lock_bh(&sock->sk_callback_lock);
1934                 smap_start_sock(psock, sock);
1935                 write_unlock_bh(&sock->sk_callback_lock);
1936         }
1937
1938         return err;
1939 out_free:
1940         smap_release_sock(psock, sock);
1941 out_progs:
1942         if (parse && verdict) {
1943                 bpf_prog_put(parse);
1944                 bpf_prog_put(verdict);
1945         }
1946         if (tx_msg)
1947                 bpf_prog_put(tx_msg);
1948         return err;
1949 }
1950
1951 static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1952                                     struct bpf_map *map,
1953                                     void *key, u64 flags)
1954 {
1955         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1956         struct bpf_sock_progs *progs = &stab->progs;
1957         struct sock *osock, *sock = skops->sk;
1958         struct smap_psock_map_entry *e;
1959         struct smap_psock *psock;
1960         u32 i = *(u32 *)key;
1961         int err;
1962
1963         if (unlikely(flags > BPF_EXIST))
1964                 return -EINVAL;
1965         if (unlikely(i >= stab->map.max_entries))
1966                 return -E2BIG;
1967
1968         e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
1969         if (!e)
1970                 return -ENOMEM;
1971
1972         err = __sock_map_ctx_update_elem(map, progs, sock, key);
1973         if (err)
1974                 goto out;
1975
1976         /* psock guaranteed to be present. */
1977         psock = smap_psock_sk(sock);
1978         raw_spin_lock_bh(&stab->lock);
1979         osock = stab->sock_map[i];
1980         if (osock && flags == BPF_NOEXIST) {
1981                 err = -EEXIST;
1982                 goto out_unlock;
1983         }
1984         if (!osock && flags == BPF_EXIST) {
1985                 err = -ENOENT;
1986                 goto out_unlock;
1987         }
1988
1989         e->entry = &stab->sock_map[i];
1990         e->map = map;
1991         spin_lock_bh(&psock->maps_lock);
1992         list_add_tail(&e->list, &psock->maps);
1993         spin_unlock_bh(&psock->maps_lock);
1994
1995         stab->sock_map[i] = sock;
1996         if (osock) {
1997                 psock = smap_psock_sk(osock);
1998                 smap_list_map_remove(psock, &stab->sock_map[i]);
1999                 smap_release_sock(psock, osock);
2000         }
2001         raw_spin_unlock_bh(&stab->lock);
2002         return 0;
2003 out_unlock:
2004         smap_release_sock(psock, sock);
2005         raw_spin_unlock_bh(&stab->lock);
2006 out:
2007         kfree(e);
2008         return err;
2009 }
2010
2011 int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
2012 {
2013         struct bpf_sock_progs *progs;
2014         struct bpf_prog *orig;
2015
2016         if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2017                 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2018
2019                 progs = &stab->progs;
2020         } else if (map->map_type == BPF_MAP_TYPE_SOCKHASH) {
2021                 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2022
2023                 progs = &htab->progs;
2024         } else {
2025                 return -EINVAL;
2026         }
2027
2028         switch (type) {
2029         case BPF_SK_MSG_VERDICT:
2030                 orig = xchg(&progs->bpf_tx_msg, prog);
2031                 break;
2032         case BPF_SK_SKB_STREAM_PARSER:
2033                 orig = xchg(&progs->bpf_parse, prog);
2034                 break;
2035         case BPF_SK_SKB_STREAM_VERDICT:
2036                 orig = xchg(&progs->bpf_verdict, prog);
2037                 break;
2038         default:
2039                 return -EOPNOTSUPP;
2040         }
2041
2042         if (orig)
2043                 bpf_prog_put(orig);
2044
2045         return 0;
2046 }
2047
2048 int sockmap_get_from_fd(const union bpf_attr *attr, int type,
2049                         struct bpf_prog *prog)
2050 {
2051         int ufd = attr->target_fd;
2052         struct bpf_map *map;
2053         struct fd f;
2054         int err;
2055
2056         f = fdget(ufd);
2057         map = __bpf_map_get(f);
2058         if (IS_ERR(map))
2059                 return PTR_ERR(map);
2060
2061         err = sock_map_prog(map, prog, attr->attach_type);
2062         fdput(f);
2063         return err;
2064 }
2065
2066 static void *sock_map_lookup(struct bpf_map *map, void *key)
2067 {
2068         return NULL;
2069 }
2070
2071 static int sock_map_update_elem(struct bpf_map *map,
2072                                 void *key, void *value, u64 flags)
2073 {
2074         struct bpf_sock_ops_kern skops;
2075         u32 fd = *(u32 *)value;
2076         struct socket *socket;
2077         int err;
2078
2079         socket = sockfd_lookup(fd, &err);
2080         if (!socket)
2081                 return err;
2082
2083         skops.sk = socket->sk;
2084         if (!skops.sk) {
2085                 fput(socket->file);
2086                 return -EINVAL;
2087         }
2088
2089         if (skops.sk->sk_type != SOCK_STREAM ||
2090             skops.sk->sk_protocol != IPPROTO_TCP) {
2091                 fput(socket->file);
2092                 return -EOPNOTSUPP;
2093         }
2094
2095         lock_sock(skops.sk);
2096         preempt_disable();
2097         rcu_read_lock();
2098         err = sock_map_ctx_update_elem(&skops, map, key, flags);
2099         rcu_read_unlock();
2100         preempt_enable();
2101         release_sock(skops.sk);
2102         fput(socket->file);
2103         return err;
2104 }
2105
2106 static void sock_map_release(struct bpf_map *map)
2107 {
2108         struct bpf_sock_progs *progs;
2109         struct bpf_prog *orig;
2110
2111         if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2112                 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2113
2114                 progs = &stab->progs;
2115         } else {
2116                 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2117
2118                 progs = &htab->progs;
2119         }
2120
2121         orig = xchg(&progs->bpf_parse, NULL);
2122         if (orig)
2123                 bpf_prog_put(orig);
2124         orig = xchg(&progs->bpf_verdict, NULL);
2125         if (orig)
2126                 bpf_prog_put(orig);
2127
2128         orig = xchg(&progs->bpf_tx_msg, NULL);
2129         if (orig)
2130                 bpf_prog_put(orig);
2131 }
2132
2133 static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
2134 {
2135         struct bpf_htab *htab;
2136         int i, err;
2137         u64 cost;
2138
2139         if (!capable(CAP_NET_ADMIN))
2140                 return ERR_PTR(-EPERM);
2141
2142         /* check sanity of attributes */
2143         if (attr->max_entries == 0 || attr->value_size != 4 ||
2144             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
2145                 return ERR_PTR(-EINVAL);
2146
2147         if (attr->key_size > MAX_BPF_STACK)
2148                 /* eBPF programs initialize keys on stack, so they cannot be
2149                  * larger than max stack size
2150                  */
2151                 return ERR_PTR(-E2BIG);
2152
2153         err = bpf_tcp_ulp_register();
2154         if (err && err != -EEXIST)
2155                 return ERR_PTR(err);
2156
2157         htab = kzalloc(sizeof(*htab), GFP_USER);
2158         if (!htab)
2159                 return ERR_PTR(-ENOMEM);
2160
2161         bpf_map_init_from_attr(&htab->map, attr);
2162
2163         htab->n_buckets = roundup_pow_of_two(htab->map.max_entries);
2164         htab->elem_size = sizeof(struct htab_elem) +
2165                           round_up(htab->map.key_size, 8);
2166         err = -EINVAL;
2167         if (htab->n_buckets == 0 ||
2168             htab->n_buckets > U32_MAX / sizeof(struct bucket))
2169                 goto free_htab;
2170
2171         cost = (u64) htab->n_buckets * sizeof(struct bucket) +
2172                (u64) htab->elem_size * htab->map.max_entries;
2173
2174         if (cost >= U32_MAX - PAGE_SIZE)
2175                 goto free_htab;
2176
2177         htab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
2178         err = bpf_map_precharge_memlock(htab->map.pages);
2179         if (err)
2180                 goto free_htab;
2181
2182         err = -ENOMEM;
2183         htab->buckets = bpf_map_area_alloc(
2184                                 htab->n_buckets * sizeof(struct bucket),
2185                                 htab->map.numa_node);
2186         if (!htab->buckets)
2187                 goto free_htab;
2188
2189         for (i = 0; i < htab->n_buckets; i++) {
2190                 INIT_HLIST_HEAD(&htab->buckets[i].head);
2191                 raw_spin_lock_init(&htab->buckets[i].lock);
2192         }
2193
2194         return &htab->map;
2195 free_htab:
2196         kfree(htab);
2197         return ERR_PTR(err);
2198 }
2199
2200 static void __bpf_htab_free(struct rcu_head *rcu)
2201 {
2202         struct bpf_htab *htab;
2203
2204         htab = container_of(rcu, struct bpf_htab, rcu);
2205         bpf_map_area_free(htab->buckets);
2206         kfree(htab);
2207 }
2208
2209 static void sock_hash_free(struct bpf_map *map)
2210 {
2211         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2212         int i;
2213
2214         synchronize_rcu();
2215
2216         /* At this point no update, lookup or delete operations can happen.
2217          * However, be aware we can still get a socket state event updates,
2218          * and data ready callabacks that reference the psock from sk_user_data
2219          * Also psock worker threads are still in-flight. So smap_release_sock
2220          * will only free the psock after cancel_sync on the worker threads
2221          * and a grace period expire to ensure psock is really safe to remove.
2222          */
2223         rcu_read_lock();
2224         for (i = 0; i < htab->n_buckets; i++) {
2225                 struct bucket *b = __select_bucket(htab, i);
2226                 struct hlist_head *head;
2227                 struct hlist_node *n;
2228                 struct htab_elem *l;
2229
2230                 raw_spin_lock_bh(&b->lock);
2231                 head = &b->head;
2232                 hlist_for_each_entry_safe(l, n, head, hash_node) {
2233                         struct sock *sock = l->sk;
2234                         struct smap_psock *psock;
2235
2236                         hlist_del_rcu(&l->hash_node);
2237                         psock = smap_psock_sk(sock);
2238                         /* This check handles a racing sock event that can get
2239                          * the sk_callback_lock before this case but after xchg
2240                          * causing the refcnt to hit zero and sock user data
2241                          * (psock) to be null and queued for garbage collection.
2242                          */
2243                         if (likely(psock)) {
2244                                 smap_list_hash_remove(psock, l);
2245                                 smap_release_sock(psock, sock);
2246                         }
2247                         free_htab_elem(htab, l);
2248                 }
2249                 raw_spin_unlock_bh(&b->lock);
2250         }
2251         rcu_read_unlock();
2252         call_rcu(&htab->rcu, __bpf_htab_free);
2253 }
2254
2255 static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
2256                                               void *key, u32 key_size, u32 hash,
2257                                               struct sock *sk,
2258                                               struct htab_elem *old_elem)
2259 {
2260         struct htab_elem *l_new;
2261
2262         if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
2263                 if (!old_elem) {
2264                         atomic_dec(&htab->count);
2265                         return ERR_PTR(-E2BIG);
2266                 }
2267         }
2268         l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
2269                              htab->map.numa_node);
2270         if (!l_new)
2271                 return ERR_PTR(-ENOMEM);
2272
2273         memcpy(l_new->key, key, key_size);
2274         l_new->sk = sk;
2275         l_new->hash = hash;
2276         return l_new;
2277 }
2278
2279 static inline u32 htab_map_hash(const void *key, u32 key_len)
2280 {
2281         return jhash(key, key_len, 0);
2282 }
2283
2284 static int sock_hash_get_next_key(struct bpf_map *map,
2285                                   void *key, void *next_key)
2286 {
2287         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2288         struct htab_elem *l, *next_l;
2289         struct hlist_head *h;
2290         u32 hash, key_size;
2291         int i = 0;
2292
2293         WARN_ON_ONCE(!rcu_read_lock_held());
2294
2295         key_size = map->key_size;
2296         if (!key)
2297                 goto find_first_elem;
2298         hash = htab_map_hash(key, key_size);
2299         h = select_bucket(htab, hash);
2300
2301         l = lookup_elem_raw(h, hash, key, key_size);
2302         if (!l)
2303                 goto find_first_elem;
2304         next_l = hlist_entry_safe(
2305                      rcu_dereference_raw(hlist_next_rcu(&l->hash_node)),
2306                      struct htab_elem, hash_node);
2307         if (next_l) {
2308                 memcpy(next_key, next_l->key, key_size);
2309                 return 0;
2310         }
2311
2312         /* no more elements in this hash list, go to the next bucket */
2313         i = hash & (htab->n_buckets - 1);
2314         i++;
2315
2316 find_first_elem:
2317         /* iterate over buckets */
2318         for (; i < htab->n_buckets; i++) {
2319                 h = select_bucket(htab, i);
2320
2321                 /* pick first element in the bucket */
2322                 next_l = hlist_entry_safe(
2323                                 rcu_dereference_raw(hlist_first_rcu(h)),
2324                                 struct htab_elem, hash_node);
2325                 if (next_l) {
2326                         /* if it's not empty, just return it */
2327                         memcpy(next_key, next_l->key, key_size);
2328                         return 0;
2329                 }
2330         }
2331
2332         /* iterated over all buckets and all elements */
2333         return -ENOENT;
2334 }
2335
2336 static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
2337                                      struct bpf_map *map,
2338                                      void *key, u64 map_flags)
2339 {
2340         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2341         struct bpf_sock_progs *progs = &htab->progs;
2342         struct htab_elem *l_new = NULL, *l_old;
2343         struct smap_psock_map_entry *e = NULL;
2344         struct hlist_head *head;
2345         struct smap_psock *psock;
2346         u32 key_size, hash;
2347         struct sock *sock;
2348         struct bucket *b;
2349         int err;
2350
2351         sock = skops->sk;
2352
2353         if (sock->sk_type != SOCK_STREAM ||
2354             sock->sk_protocol != IPPROTO_TCP)
2355                 return -EOPNOTSUPP;
2356
2357         if (unlikely(map_flags > BPF_EXIST))
2358                 return -EINVAL;
2359
2360         e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2361         if (!e)
2362                 return -ENOMEM;
2363
2364         WARN_ON_ONCE(!rcu_read_lock_held());
2365         key_size = map->key_size;
2366         hash = htab_map_hash(key, key_size);
2367         b = __select_bucket(htab, hash);
2368         head = &b->head;
2369
2370         err = __sock_map_ctx_update_elem(map, progs, sock, key);
2371         if (err)
2372                 goto err;
2373
2374         /* psock is valid here because otherwise above *ctx_update_elem would
2375          * have thrown an error. It is safe to skip error check.
2376          */
2377         psock = smap_psock_sk(sock);
2378         raw_spin_lock_bh(&b->lock);
2379         l_old = lookup_elem_raw(head, hash, key, key_size);
2380         if (l_old && map_flags == BPF_NOEXIST) {
2381                 err = -EEXIST;
2382                 goto bucket_err;
2383         }
2384         if (!l_old && map_flags == BPF_EXIST) {
2385                 err = -ENOENT;
2386                 goto bucket_err;
2387         }
2388
2389         l_new = alloc_sock_hash_elem(htab, key, key_size, hash, sock, l_old);
2390         if (IS_ERR(l_new)) {
2391                 err = PTR_ERR(l_new);
2392                 goto bucket_err;
2393         }
2394
2395         rcu_assign_pointer(e->hash_link, l_new);
2396         e->map = map;
2397         spin_lock_bh(&psock->maps_lock);
2398         list_add_tail(&e->list, &psock->maps);
2399         spin_unlock_bh(&psock->maps_lock);
2400
2401         /* add new element to the head of the list, so that
2402          * concurrent search will find it before old elem
2403          */
2404         hlist_add_head_rcu(&l_new->hash_node, head);
2405         if (l_old) {
2406                 psock = smap_psock_sk(l_old->sk);
2407
2408                 hlist_del_rcu(&l_old->hash_node);
2409                 smap_list_hash_remove(psock, l_old);
2410                 smap_release_sock(psock, l_old->sk);
2411                 free_htab_elem(htab, l_old);
2412         }
2413         raw_spin_unlock_bh(&b->lock);
2414         return 0;
2415 bucket_err:
2416         smap_release_sock(psock, sock);
2417         raw_spin_unlock_bh(&b->lock);
2418 err:
2419         kfree(e);
2420         return err;
2421 }
2422
2423 static int sock_hash_update_elem(struct bpf_map *map,
2424                                 void *key, void *value, u64 flags)
2425 {
2426         struct bpf_sock_ops_kern skops;
2427         u32 fd = *(u32 *)value;
2428         struct socket *socket;
2429         int err;
2430
2431         socket = sockfd_lookup(fd, &err);
2432         if (!socket)
2433                 return err;
2434
2435         skops.sk = socket->sk;
2436         if (!skops.sk) {
2437                 fput(socket->file);
2438                 return -EINVAL;
2439         }
2440
2441         lock_sock(skops.sk);
2442         preempt_disable();
2443         rcu_read_lock();
2444         err = sock_hash_ctx_update_elem(&skops, map, key, flags);
2445         rcu_read_unlock();
2446         preempt_enable();
2447         release_sock(skops.sk);
2448         fput(socket->file);
2449         return err;
2450 }
2451
2452 static int sock_hash_delete_elem(struct bpf_map *map, void *key)
2453 {
2454         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2455         struct hlist_head *head;
2456         struct bucket *b;
2457         struct htab_elem *l;
2458         u32 hash, key_size;
2459         int ret = -ENOENT;
2460
2461         key_size = map->key_size;
2462         hash = htab_map_hash(key, key_size);
2463         b = __select_bucket(htab, hash);
2464         head = &b->head;
2465
2466         raw_spin_lock_bh(&b->lock);
2467         l = lookup_elem_raw(head, hash, key, key_size);
2468         if (l) {
2469                 struct sock *sock = l->sk;
2470                 struct smap_psock *psock;
2471
2472                 hlist_del_rcu(&l->hash_node);
2473                 psock = smap_psock_sk(sock);
2474                 /* This check handles a racing sock event that can get the
2475                  * sk_callback_lock before this case but after xchg happens
2476                  * causing the refcnt to hit zero and sock user data (psock)
2477                  * to be null and queued for garbage collection.
2478                  */
2479                 if (likely(psock)) {
2480                         smap_list_hash_remove(psock, l);
2481                         smap_release_sock(psock, sock);
2482                 }
2483                 free_htab_elem(htab, l);
2484                 ret = 0;
2485         }
2486         raw_spin_unlock_bh(&b->lock);
2487         return ret;
2488 }
2489
2490 struct sock  *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
2491 {
2492         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2493         struct hlist_head *head;
2494         struct htab_elem *l;
2495         u32 key_size, hash;
2496         struct bucket *b;
2497         struct sock *sk;
2498
2499         key_size = map->key_size;
2500         hash = htab_map_hash(key, key_size);
2501         b = __select_bucket(htab, hash);
2502         head = &b->head;
2503
2504         l = lookup_elem_raw(head, hash, key, key_size);
2505         sk = l ? l->sk : NULL;
2506         return sk;
2507 }
2508
2509 const struct bpf_map_ops sock_map_ops = {
2510         .map_alloc = sock_map_alloc,
2511         .map_free = sock_map_free,
2512         .map_lookup_elem = sock_map_lookup,
2513         .map_get_next_key = sock_map_get_next_key,
2514         .map_update_elem = sock_map_update_elem,
2515         .map_delete_elem = sock_map_delete_elem,
2516         .map_release_uref = sock_map_release,
2517         .map_check_btf = map_check_no_btf,
2518 };
2519
2520 const struct bpf_map_ops sock_hash_ops = {
2521         .map_alloc = sock_hash_alloc,
2522         .map_free = sock_hash_free,
2523         .map_lookup_elem = sock_map_lookup,
2524         .map_get_next_key = sock_hash_get_next_key,
2525         .map_update_elem = sock_hash_update_elem,
2526         .map_delete_elem = sock_hash_delete_elem,
2527         .map_release_uref = sock_map_release,
2528         .map_check_btf = map_check_no_btf,
2529 };
2530
2531 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
2532            struct bpf_map *, map, void *, key, u64, flags)
2533 {
2534         WARN_ON_ONCE(!rcu_read_lock_held());
2535         return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
2536 }
2537
2538 const struct bpf_func_proto bpf_sock_map_update_proto = {
2539         .func           = bpf_sock_map_update,
2540         .gpl_only       = false,
2541         .pkt_access     = true,
2542         .ret_type       = RET_INTEGER,
2543         .arg1_type      = ARG_PTR_TO_CTX,
2544         .arg2_type      = ARG_CONST_MAP_PTR,
2545         .arg3_type      = ARG_PTR_TO_MAP_KEY,
2546         .arg4_type      = ARG_ANYTHING,
2547 };
2548
2549 BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, bpf_sock,
2550            struct bpf_map *, map, void *, key, u64, flags)
2551 {
2552         WARN_ON_ONCE(!rcu_read_lock_held());
2553         return sock_hash_ctx_update_elem(bpf_sock, map, key, flags);
2554 }
2555
2556 const struct bpf_func_proto bpf_sock_hash_update_proto = {
2557         .func           = bpf_sock_hash_update,
2558         .gpl_only       = false,
2559         .pkt_access     = true,
2560         .ret_type       = RET_INTEGER,
2561         .arg1_type      = ARG_PTR_TO_CTX,
2562         .arg2_type      = ARG_CONST_MAP_PTR,
2563         .arg3_type      = ARG_PTR_TO_MAP_KEY,
2564         .arg4_type      = ARG_ANYTHING,
2565 };