// SPDX-License-Identifier: GPL-2.0
/* XDP sockets
 *
 * AF_XDP sockets allow a channel between XDP programs and userspace
 * applications.
 * Copyright(c) 2018 Intel Corporation.
 *
 * Author(s): Björn Töpel <bjorn.topel@intel.com>
 *            Magnus Karlsson <magnus.karlsson@intel.com>
 */

#define pr_fmt(fmt) "AF_XDP: %s: " fmt, __func__

#include <linux/if_xdp.h>
#include <linux/init.h>
#include <linux/sched/mm.h>
#include <linux/sched/signal.h>
#include <linux/sched/task.h>
#include <linux/socket.h>
#include <linux/file.h>
#include <linux/uaccess.h>
#include <linux/net.h>
#include <linux/netdevice.h>
#include <linux/rculist.h>
#include <net/xdp_sock_drv.h>
#include <net/xdp.h>

#include "xsk_queue.h"
#include "xdp_umem.h"
#include "xsk.h"

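/* Upper bound on the number of Tx descriptors handled per xsk_generic_xmit()
 * call before -EAGAIN is returned to the caller.
 */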
#define TX_BATCH_SIZE 16

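/* Per-cpu list of sockets that received packets via XDP_REDIRECT during the
 * current NAPI cycle; drained by __xsk_map_flush().
 */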
static DEFINE_PER_CPU(struct list_head, xskmap_flush_list);

void xsk_set_rx_need_wakeup(struct xsk_buff_pool *pool)
{
        if (pool->cached_need_wakeup & XDP_WAKEUP_RX)
                return;

        pool->fq->ring->flags |= XDP_RING_NEED_WAKEUP;
        pool->cached_need_wakeup |= XDP_WAKEUP_RX;
}
EXPORT_SYMBOL(xsk_set_rx_need_wakeup);

void xsk_set_tx_need_wakeup(struct xsk_buff_pool *pool)
{
        struct xdp_sock *xs;

        if (pool->cached_need_wakeup & XDP_WAKEUP_TX)
                return;

        rcu_read_lock();
        list_for_each_entry_rcu(xs, &pool->xsk_tx_list, tx_list) {
                xs->tx->ring->flags |= XDP_RING_NEED_WAKEUP;
        }
        rcu_read_unlock();

        pool->cached_need_wakeup |= XDP_WAKEUP_TX;
}
EXPORT_SYMBOL(xsk_set_tx_need_wakeup);

void xsk_clear_rx_need_wakeup(struct xsk_buff_pool *pool)
{
        if (!(pool->cached_need_wakeup & XDP_WAKEUP_RX))
                return;

        pool->fq->ring->flags &= ~XDP_RING_NEED_WAKEUP;
        pool->cached_need_wakeup &= ~XDP_WAKEUP_RX;
}
EXPORT_SYMBOL(xsk_clear_rx_need_wakeup);

void xsk_clear_tx_need_wakeup(struct xsk_buff_pool *pool)
{
        struct xdp_sock *xs;

        if (!(pool->cached_need_wakeup & XDP_WAKEUP_TX))
                return;

        rcu_read_lock();
        list_for_each_entry_rcu(xs, &pool->xsk_tx_list, tx_list) {
                xs->tx->ring->flags &= ~XDP_RING_NEED_WAKEUP;
        }
        rcu_read_unlock();

        pool->cached_need_wakeup &= ~XDP_WAKEUP_TX;
}
EXPORT_SYMBOL(xsk_clear_tx_need_wakeup);

bool xsk_uses_need_wakeup(struct xsk_buff_pool *pool)
{
        return pool->uses_need_wakeup;
}
EXPORT_SYMBOL(xsk_uses_need_wakeup);

struct xsk_buff_pool *xsk_get_pool_from_qid(struct net_device *dev,
                                            u16 queue_id)
{
        if (queue_id < dev->real_num_rx_queues)
                return dev->_rx[queue_id].pool;
        if (queue_id < dev->real_num_tx_queues)
                return dev->_tx[queue_id].pool;

        return NULL;
}
EXPORT_SYMBOL(xsk_get_pool_from_qid);

void xsk_clear_pool_at_qid(struct net_device *dev, u16 queue_id)
{
        if (queue_id < dev->real_num_rx_queues)
                dev->_rx[queue_id].pool = NULL;
        if (queue_id < dev->real_num_tx_queues)
                dev->_tx[queue_id].pool = NULL;
}

/* The buffer pool is stored both in the _rx struct and the _tx struct as we do
 * not know if the device has more tx queues than rx, or the opposite.
 * This might also change during run time.
 */
int xsk_reg_pool_at_qid(struct net_device *dev, struct xsk_buff_pool *pool,
                        u16 queue_id)
{
        if (queue_id >= max_t(unsigned int,
                              dev->real_num_rx_queues,
                              dev->real_num_tx_queues))
                return -EINVAL;

        if (queue_id < dev->real_num_rx_queues)
                dev->_rx[queue_id].pool = pool;
        if (queue_id < dev->real_num_tx_queues)
                dev->_tx[queue_id].pool = pool;

        return 0;
}

void xp_release(struct xdp_buff_xsk *xskb)
{
        xskb->pool->free_heads[xskb->pool->free_heads_cnt++] = xskb;
}

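/* Translate a buffer back into the address format used on the Rx ring. In
 * unaligned mode the offset within the chunk is carried in the upper bits
 * of the handle (shifted by XSK_UNALIGNED_BUF_OFFSET_SHIFT).
 */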
static u64 xp_get_handle(struct xdp_buff_xsk *xskb)
{
        u64 offset = xskb->xdp.data - xskb->xdp.data_hard_start;

        offset += xskb->pool->headroom;
        if (!xskb->pool->unaligned)
                return xskb->orig_addr + offset;
        return xskb->orig_addr + (offset << XSK_UNALIGNED_BUF_OFFSET_SHIFT);
}

static int __xsk_rcv_zc(struct xdp_sock *xs, struct xdp_buff *xdp, u32 len)
{
        struct xdp_buff_xsk *xskb = container_of(xdp, struct xdp_buff_xsk, xdp);
        u64 addr;
        int err;

        addr = xp_get_handle(xskb);
        err = xskq_prod_reserve_desc(xs->rx, addr, len);
        if (err) {
                xs->rx_queue_full++;
                return err;
        }

        xp_release(xskb);
        return 0;
}

static void xsk_copy_xdp(struct xdp_buff *to, struct xdp_buff *from, u32 len)
{
        void *from_buf, *to_buf;
        u32 metalen;

        if (unlikely(xdp_data_meta_unsupported(from))) {
                from_buf = from->data;
                to_buf = to->data;
                metalen = 0;
        } else {
                from_buf = from->data_meta;
                metalen = from->data - from->data_meta;
                to_buf = to->data - metalen;
        }

        memcpy(to_buf, from_buf, len + metalen);
}

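/* Copy-mode Rx: allocate a buffer from the pool, copy the packet (and any
 * metadata) into it and post it to the Rx ring.
 */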
static int __xsk_rcv(struct xdp_sock *xs, struct xdp_buff *xdp, u32 len,
                     bool explicit_free)
{
        struct xdp_buff *xsk_xdp;
        int err;

        if (len > xsk_pool_get_rx_frame_size(xs->pool)) {
                xs->rx_dropped++;
                return -ENOSPC;
        }

        xsk_xdp = xsk_buff_alloc(xs->pool);
        if (!xsk_xdp) {
                xs->rx_dropped++;
                return -ENOSPC;
        }

        xsk_copy_xdp(xsk_xdp, xdp, len);
        err = __xsk_rcv_zc(xs, xsk_xdp, len);
        if (err) {
                xsk_buff_free(xsk_xdp);
                return err;
        }
        if (explicit_free)
                xdp_return_buff(xdp);
        return 0;
}

static bool xsk_is_bound(struct xdp_sock *xs)
{
        if (READ_ONCE(xs->state) == XSK_BOUND) {
                /* Matches smp_wmb() in bind(). */
                smp_rmb();
                return true;
        }
        return false;
}

static int xsk_rcv(struct xdp_sock *xs, struct xdp_buff *xdp,
                   bool explicit_free)
{
        u32 len;

        if (!xsk_is_bound(xs))
                return -EINVAL;

        if (xs->dev != xdp->rxq->dev || xs->queue_id != xdp->rxq->queue_index)
                return -EINVAL;

        len = xdp->data_end - xdp->data;

        return xdp->rxq->mem.type == MEM_TYPE_XSK_BUFF_POOL ?
                __xsk_rcv_zc(xs, xdp, len) :
                __xsk_rcv(xs, xdp, len, explicit_free);
}

static void xsk_flush(struct xdp_sock *xs)
{
        xskq_prod_submit(xs->rx);
        __xskq_cons_release(xs->pool->fq);
        sock_def_readable(&xs->sk);
}

int xsk_generic_rcv(struct xdp_sock *xs, struct xdp_buff *xdp)
{
        int err;

        spin_lock_bh(&xs->rx_lock);
        err = xsk_rcv(xs, xdp, false);
        xsk_flush(xs);
        spin_unlock_bh(&xs->rx_lock);
        return err;
}

int __xsk_map_redirect(struct xdp_sock *xs, struct xdp_buff *xdp)
{
        struct list_head *flush_list = this_cpu_ptr(&xskmap_flush_list);
        int err;

        err = xsk_rcv(xs, xdp, true);
        if (err)
                return err;

        if (!xs->flush_node.prev)
                list_add(&xs->flush_node, flush_list);

        return 0;
}

void __xsk_map_flush(void)
{
        struct list_head *flush_list = this_cpu_ptr(&xskmap_flush_list);
        struct xdp_sock *xs, *tmp;

        list_for_each_entry_safe(xs, tmp, flush_list, flush_node) {
                xsk_flush(xs);
                __list_del_clearprev(&xs->flush_node);
        }
}

void xsk_tx_completed(struct xsk_buff_pool *pool, u32 nb_entries)
{
        xskq_prod_submit_n(pool->cq, nb_entries);
}
EXPORT_SYMBOL(xsk_tx_completed);

void xsk_tx_release(struct xsk_buff_pool *pool)
{
        struct xdp_sock *xs;

        rcu_read_lock();
        list_for_each_entry_rcu(xs, &pool->xsk_tx_list, tx_list) {
                __xskq_cons_release(xs->tx);
                xs->sk.sk_write_space(&xs->sk);
        }
        rcu_read_unlock();
}
EXPORT_SYMBOL(xsk_tx_release);

bool xsk_tx_peek_desc(struct xsk_buff_pool *pool, struct xdp_desc *desc)
{
        struct xdp_sock *xs;

        rcu_read_lock();
        list_for_each_entry_rcu(xs, &pool->xsk_tx_list, tx_list) {
                if (!xskq_cons_peek_desc(xs->tx, desc, pool)) {
                        xs->tx->queue_empty_descs++;
                        continue;
                }

                /* This is the backpressure mechanism for the Tx path.
                 * Reserve space in the completion queue and only proceed
                 * if there is space in it. This avoids having to implement
                 * any buffering in the Tx path.
                 */
                if (xskq_prod_reserve_addr(pool->cq, desc->addr))
                        goto out;

                xskq_cons_release(xs->tx);
                rcu_read_unlock();
                return true;
        }

out:
        rcu_read_unlock();
        return false;
}
EXPORT_SYMBOL(xsk_tx_peek_desc);

static int xsk_wakeup(struct xdp_sock *xs, u8 flags)
{
        struct net_device *dev = xs->dev;
        int err;

        rcu_read_lock();
        err = dev->netdev_ops->ndo_xsk_wakeup(dev, xs->queue_id, flags);
        rcu_read_unlock();

        return err;
}

static int xsk_zc_xmit(struct xdp_sock *xs)
{
        return xsk_wakeup(xs, XDP_WAKEUP_TX);
}

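/* skb destructor for copy-mode Tx: publish the descriptor address on the
 * completion ring once the skb has been consumed.
 */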
static void xsk_destruct_skb(struct sk_buff *skb)
{
        u64 addr = (u64)(long)skb_shinfo(skb)->destructor_arg;
        struct xdp_sock *xs = xdp_sk(skb->sk);
        unsigned long flags;

        spin_lock_irqsave(&xs->tx_completion_lock, flags);
        xskq_prod_submit_addr(xs->pool->cq, addr);
        spin_unlock_irqrestore(&xs->tx_completion_lock, flags);

        sock_wfree(skb);
}

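/* Copy-mode Tx: build one skb per Tx descriptor and send it with
 * dev_direct_xmit() on the queue the socket is bound to.
 */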
static int xsk_generic_xmit(struct sock *sk)
{
        struct xdp_sock *xs = xdp_sk(sk);
        u32 max_batch = TX_BATCH_SIZE;
        bool sent_frame = false;
        struct xdp_desc desc;
        struct sk_buff *skb;
        int err = 0;

        mutex_lock(&xs->mutex);

        if (xs->queue_id >= xs->dev->real_num_tx_queues)
                goto out;

        while (xskq_cons_peek_desc(xs->tx, &desc, xs->pool)) {
                char *buffer;
                u64 addr;
                u32 len;

                if (max_batch-- == 0) {
                        err = -EAGAIN;
                        goto out;
                }

                len = desc.len;
                skb = sock_alloc_send_skb(sk, len, 1, &err);
                if (unlikely(!skb))
                        goto out;

                skb_put(skb, len);
                addr = desc.addr;
                buffer = xsk_buff_raw_get_data(xs->pool, addr);
                err = skb_store_bits(skb, 0, buffer, len);
                /* This is the backpressure mechanism for the Tx path.
                 * Reserve space in the completion queue and only proceed
                 * if there is space in it. This avoids having to implement
                 * any buffering in the Tx path.
                 */
                if (unlikely(err) || xskq_prod_reserve(xs->pool->cq)) {
                        kfree_skb(skb);
                        goto out;
                }

                skb->dev = xs->dev;
                skb->priority = sk->sk_priority;
                skb->mark = sk->sk_mark;
                skb_shinfo(skb)->destructor_arg = (void *)(long)desc.addr;
                skb->destructor = xsk_destruct_skb;

                /* Hinder dev_direct_xmit from freeing the packet and
                 * therefore completing it in the destructor
                 */
                refcount_inc(&skb->users);
                err = dev_direct_xmit(skb, xs->queue_id);
                if (err == NETDEV_TX_BUSY) {
                        /* Tell user-space to retry the send */
                        skb->destructor = sock_wfree;
                        /* Free skb without triggering the perf drop trace */
                        consume_skb(skb);
                        err = -EAGAIN;
                        goto out;
                }

                xskq_cons_release(xs->tx);
                /* Ignore NET_XMIT_CN as packet might have been sent */
                if (err == NET_XMIT_DROP) {
                        /* SKB completed but not sent */
                        kfree_skb(skb);
                        err = -EBUSY;
                        goto out;
                }

                consume_skb(skb);
                sent_frame = true;
        }

        xs->tx->queue_empty_descs++;

out:
        if (sent_frame)
                sk->sk_write_space(sk);

        mutex_unlock(&xs->mutex);
        return err;
}

static int __xsk_sendmsg(struct sock *sk)
{
        struct xdp_sock *xs = xdp_sk(sk);

        if (unlikely(!(xs->dev->flags & IFF_UP)))
                return -ENETDOWN;
        if (unlikely(!xs->tx))
                return -ENOBUFS;

        return xs->zc ? xsk_zc_xmit(xs) : xsk_generic_xmit(sk);
}

static int xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len)
{
        bool need_wait = !(m->msg_flags & MSG_DONTWAIT);
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);

        if (unlikely(!xsk_is_bound(xs)))
                return -ENXIO;
        if (unlikely(need_wait))
                return -EOPNOTSUPP;

        return __xsk_sendmsg(sk);
}

static __poll_t xsk_poll(struct file *file, struct socket *sock,
                         struct poll_table_struct *wait)
{
        __poll_t mask = datagram_poll(file, sock, wait);
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);
        struct xsk_buff_pool *pool;

        if (unlikely(!xsk_is_bound(xs)))
                return mask;

        pool = xs->pool;

        if (pool->cached_need_wakeup) {
                if (xs->zc)
                        xsk_wakeup(xs, pool->cached_need_wakeup);
                else
                        /* Poll needs to drive Tx also in copy mode */
                        __xsk_sendmsg(sk);
        }

        if (xs->rx && !xskq_prod_is_empty(xs->rx))
                mask |= EPOLLIN | EPOLLRDNORM;
        if (xs->tx && !xskq_cons_is_full(xs->tx))
                mask |= EPOLLOUT | EPOLLWRNORM;

        return mask;
}

static int xsk_init_queue(u32 entries, struct xsk_queue **queue,
                          bool umem_queue)
{
        struct xsk_queue *q;

        if (entries == 0 || *queue || !is_power_of_2(entries))
                return -EINVAL;

        q = xskq_create(entries, umem_queue);
        if (!q)
                return -ENOMEM;

        /* Make sure queue is ready before it can be seen by others */
        smp_wmb();
        WRITE_ONCE(*queue, q);
        return 0;
}

static void xsk_unbind_dev(struct xdp_sock *xs)
{
        struct net_device *dev = xs->dev;

        if (xs->state != XSK_BOUND)
                return;
        WRITE_ONCE(xs->state, XSK_UNBOUND);

        /* Wait for driver to stop using the xdp socket. */
        xp_del_xsk(xs->pool, xs);
        xs->dev = NULL;
        synchronize_net();
        dev_put(dev);
}

static struct xsk_map *xsk_get_map_list_entry(struct xdp_sock *xs,
                                              struct xdp_sock ***map_entry)
{
        struct xsk_map *map = NULL;
        struct xsk_map_node *node;

        *map_entry = NULL;

        spin_lock_bh(&xs->map_list_lock);
        node = list_first_entry_or_null(&xs->map_list, struct xsk_map_node,
                                        node);
        if (node) {
                WARN_ON(xsk_map_inc(node->map));
                map = node->map;
                *map_entry = node->map_entry;
        }
        spin_unlock_bh(&xs->map_list_lock);
        return map;
}

static void xsk_delete_from_maps(struct xdp_sock *xs)
{
        /* This function removes the current XDP socket from all the
         * maps it resides in. We need to take extra care here, due to
         * the two locks involved. Each map has a lock synchronizing
         * updates to the entries, and each socket has a lock that
         * synchronizes access to the list of maps (map_list). For
         * deadlock avoidance the locks need to be taken in the order
         * "map lock"->"socket map list lock". We start off by
         * accessing the socket map list, and take a reference to the
         * map to guarantee existence between the
         * xsk_get_map_list_entry() and xsk_map_try_sock_delete()
         * calls. Then we ask the map to remove the socket, which
         * tries to remove the socket from the map. Note that there
         * might be updates to the map between
         * xsk_get_map_list_entry() and xsk_map_try_sock_delete().
         */
        struct xdp_sock **map_entry = NULL;
        struct xsk_map *map;

        while ((map = xsk_get_map_list_entry(xs, &map_entry))) {
                xsk_map_try_sock_delete(map, xs, map_entry);
                xsk_map_put(map);
        }
}

static int xsk_release(struct socket *sock)
{
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);
        struct net *net;

        if (!sk)
                return 0;

        net = sock_net(sk);

        mutex_lock(&net->xdp.lock);
        sk_del_node_init_rcu(sk);
        mutex_unlock(&net->xdp.lock);

        local_bh_disable();
        sock_prot_inuse_add(net, sk->sk_prot, -1);
        local_bh_enable();

        xsk_delete_from_maps(xs);
        mutex_lock(&xs->mutex);
        xsk_unbind_dev(xs);
        mutex_unlock(&xs->mutex);

        xskq_destroy(xs->rx);
        xskq_destroy(xs->tx);
        xskq_destroy(xs->fq_tmp);
        xskq_destroy(xs->cq_tmp);

        sock_orphan(sk);
        sock->sk = NULL;

        sk_refcnt_debug_release(sk);
        sock_put(sk);

        return 0;
}

static struct socket *xsk_lookup_xsk_from_fd(int fd)
{
        struct socket *sock;
        int err;

        sock = sockfd_lookup(fd, &err);
        if (!sock)
                return ERR_PTR(-ENOTSOCK);

        if (sock->sk->sk_family != PF_XDP) {
                sockfd_put(sock);
                return ERR_PTR(-ENOPROTOOPT);
        }

        return sock;
}

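/* A socket that registers its own umem must have created both a fill ring
 * and a completion ring before bind() can succeed.
 */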
static bool xsk_validate_queues(struct xdp_sock *xs)
{
        return xs->fq_tmp && xs->cq_tmp;
}

static int xsk_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
        struct sockaddr_xdp *sxdp = (struct sockaddr_xdp *)addr;
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);
        struct net_device *dev;
        u32 flags, qid;
        int err = 0;

        if (addr_len < sizeof(struct sockaddr_xdp))
                return -EINVAL;
        if (sxdp->sxdp_family != AF_XDP)
                return -EINVAL;

        flags = sxdp->sxdp_flags;
        if (flags & ~(XDP_SHARED_UMEM | XDP_COPY | XDP_ZEROCOPY |
                      XDP_USE_NEED_WAKEUP))
                return -EINVAL;

        rtnl_lock();
        mutex_lock(&xs->mutex);
        if (xs->state != XSK_READY) {
                err = -EBUSY;
                goto out_release;
        }

        dev = dev_get_by_index(sock_net(sk), sxdp->sxdp_ifindex);
        if (!dev) {
                err = -ENODEV;
                goto out_release;
        }

        if (!xs->rx && !xs->tx) {
                err = -EINVAL;
                goto out_unlock;
        }

        qid = sxdp->sxdp_queue_id;

        if (flags & XDP_SHARED_UMEM) {
                struct xdp_sock *umem_xs;
                struct socket *sock;

                if ((flags & XDP_COPY) || (flags & XDP_ZEROCOPY) ||
                    (flags & XDP_USE_NEED_WAKEUP)) {
                        /* Cannot specify flags for shared sockets. */
                        err = -EINVAL;
                        goto out_unlock;
                }

                if (xs->umem) {
                        /* We already have our own. */
                        err = -EINVAL;
                        goto out_unlock;
                }

                sock = xsk_lookup_xsk_from_fd(sxdp->sxdp_shared_umem_fd);
                if (IS_ERR(sock)) {
                        err = PTR_ERR(sock);
                        goto out_unlock;
                }

                umem_xs = xdp_sk(sock->sk);
                if (!xsk_is_bound(umem_xs)) {
                        err = -EBADF;
                        sockfd_put(sock);
                        goto out_unlock;
                }

                if (umem_xs->queue_id != qid || umem_xs->dev != dev) {
                        /* Share the umem with another socket on another qid
                         * and/or device.
                         */
                        xs->pool = xp_create_and_assign_umem(xs,
                                                             umem_xs->umem);
                        if (!xs->pool) {
                                err = -ENOMEM;
                                sockfd_put(sock);
                                goto out_unlock;
                        }

                        err = xp_assign_dev_shared(xs->pool, umem_xs->umem,
                                                   dev, qid);
                        if (err) {
                                xp_destroy(xs->pool);
                                xs->pool = NULL;
                                sockfd_put(sock);
                                goto out_unlock;
                        }
                } else {
                        /* Share the buffer pool with the other socket. */
                        if (xs->fq_tmp || xs->cq_tmp) {
                                /* Do not allow setting your own fq or cq. */
                                err = -EINVAL;
                                sockfd_put(sock);
                                goto out_unlock;
                        }

                        xp_get_pool(umem_xs->pool);
                        xs->pool = umem_xs->pool;
                }

                xdp_get_umem(umem_xs->umem);
                WRITE_ONCE(xs->umem, umem_xs->umem);
                sockfd_put(sock);
        } else if (!xs->umem || !xsk_validate_queues(xs)) {
                err = -EINVAL;
                goto out_unlock;
        } else {
                /* This xsk has its own umem. */
                xs->pool = xp_create_and_assign_umem(xs, xs->umem);
                if (!xs->pool) {
                        err = -ENOMEM;
                        goto out_unlock;
                }

                err = xp_assign_dev(xs->pool, dev, qid, flags);
                if (err) {
                        xp_destroy(xs->pool);
                        xs->pool = NULL;
                        goto out_unlock;
                }
        }

        xs->dev = dev;
        xs->zc = xs->umem->zc;
        xs->queue_id = qid;
        xp_add_xsk(xs->pool, xs);

out_unlock:
        if (err) {
                dev_put(dev);
        } else {
                /* Matches smp_rmb() in bind() for shared umem
                 * sockets, and xsk_is_bound().
                 */
                smp_wmb();
                WRITE_ONCE(xs->state, XSK_BOUND);
        }
out_release:
        mutex_unlock(&xs->mutex);
        rtnl_unlock();
        return err;
}

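/* Layout of struct xdp_umem_reg before the trailing flags field was added;
 * accepted for compatibility when userspace passes the shorter optlen.
 */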
struct xdp_umem_reg_v1 {
        __u64 addr; /* Start of packet data area */
        __u64 len; /* Length of packet data area */
        __u32 chunk_size;
        __u32 headroom;
};

static int xsk_setsockopt(struct socket *sock, int level, int optname,
                          sockptr_t optval, unsigned int optlen)
{
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);
        int err;

        if (level != SOL_XDP)
                return -ENOPROTOOPT;

        switch (optname) {
        case XDP_RX_RING:
        case XDP_TX_RING:
        {
                struct xsk_queue **q;
                int entries;

                if (optlen < sizeof(entries))
                        return -EINVAL;
                if (copy_from_sockptr(&entries, optval, sizeof(entries)))
                        return -EFAULT;

                mutex_lock(&xs->mutex);
                if (xs->state != XSK_READY) {
                        mutex_unlock(&xs->mutex);
                        return -EBUSY;
                }
                q = (optname == XDP_TX_RING) ? &xs->tx : &xs->rx;
                err = xsk_init_queue(entries, q, false);
                if (!err && optname == XDP_TX_RING)
                        /* Tx needs to be explicitly woken up the first time */
                        xs->tx->ring->flags |= XDP_RING_NEED_WAKEUP;
                mutex_unlock(&xs->mutex);
                return err;
        }
        case XDP_UMEM_REG:
        {
                size_t mr_size = sizeof(struct xdp_umem_reg);
                struct xdp_umem_reg mr = {};
                struct xdp_umem *umem;

                if (optlen < sizeof(struct xdp_umem_reg_v1))
                        return -EINVAL;
                else if (optlen < sizeof(mr))
                        mr_size = sizeof(struct xdp_umem_reg_v1);

                if (copy_from_sockptr(&mr, optval, mr_size))
                        return -EFAULT;

                mutex_lock(&xs->mutex);
                if (xs->state != XSK_READY || xs->umem) {
                        mutex_unlock(&xs->mutex);
                        return -EBUSY;
                }

                umem = xdp_umem_create(&mr);
                if (IS_ERR(umem)) {
                        mutex_unlock(&xs->mutex);
                        return PTR_ERR(umem);
                }

                /* Make sure umem is ready before it can be seen by others */
                smp_wmb();
                WRITE_ONCE(xs->umem, umem);
                mutex_unlock(&xs->mutex);
                return 0;
        }
        case XDP_UMEM_FILL_RING:
        case XDP_UMEM_COMPLETION_RING:
        {
                struct xsk_queue **q;
                int entries;

                if (copy_from_sockptr(&entries, optval, sizeof(entries)))
                        return -EFAULT;

                mutex_lock(&xs->mutex);
                if (xs->state != XSK_READY) {
                        mutex_unlock(&xs->mutex);
                        return -EBUSY;
                }

                q = (optname == XDP_UMEM_FILL_RING) ? &xs->fq_tmp :
                        &xs->cq_tmp;
                err = xsk_init_queue(entries, q, true);
                mutex_unlock(&xs->mutex);
                return err;
        }
        default:
                break;
        }

        return -ENOPROTOOPT;
}

static void xsk_enter_rxtx_offsets(struct xdp_ring_offset_v1 *ring)
{
        ring->producer = offsetof(struct xdp_rxtx_ring, ptrs.producer);
        ring->consumer = offsetof(struct xdp_rxtx_ring, ptrs.consumer);
        ring->desc = offsetof(struct xdp_rxtx_ring, desc);
}

static void xsk_enter_umem_offsets(struct xdp_ring_offset_v1 *ring)
{
        ring->producer = offsetof(struct xdp_umem_ring, ptrs.producer);
        ring->consumer = offsetof(struct xdp_umem_ring, ptrs.consumer);
        ring->desc = offsetof(struct xdp_umem_ring, desc);
}

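/* Older layout of struct xdp_statistics without the ring full/empty
 * descriptor counters; used when userspace passes the shorter optlen.
 */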
struct xdp_statistics_v1 {
        __u64 rx_dropped;
        __u64 rx_invalid_descs;
        __u64 tx_invalid_descs;
};

static int xsk_getsockopt(struct socket *sock, int level, int optname,
                          char __user *optval, int __user *optlen)
{
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);
        int len;

        if (level != SOL_XDP)
                return -ENOPROTOOPT;

        if (get_user(len, optlen))
                return -EFAULT;
        if (len < 0)
                return -EINVAL;

        switch (optname) {
        case XDP_STATISTICS:
        {
                struct xdp_statistics stats = {};
                bool extra_stats = true;
                size_t stats_size;

                if (len < sizeof(struct xdp_statistics_v1)) {
                        return -EINVAL;
                } else if (len < sizeof(stats)) {
                        extra_stats = false;
                        stats_size = sizeof(struct xdp_statistics_v1);
                } else {
                        stats_size = sizeof(stats);
                }

                mutex_lock(&xs->mutex);
                stats.rx_dropped = xs->rx_dropped;
                if (extra_stats) {
                        stats.rx_ring_full = xs->rx_queue_full;
                        stats.rx_fill_ring_empty_descs =
                                xs->pool ? xskq_nb_queue_empty_descs(xs->pool->fq) : 0;
                        stats.tx_ring_empty_descs = xskq_nb_queue_empty_descs(xs->tx);
                } else {
                        stats.rx_dropped += xs->rx_queue_full;
                }
                stats.rx_invalid_descs = xskq_nb_invalid_descs(xs->rx);
                stats.tx_invalid_descs = xskq_nb_invalid_descs(xs->tx);
                mutex_unlock(&xs->mutex);

                if (copy_to_user(optval, &stats, stats_size))
                        return -EFAULT;
                if (put_user(stats_size, optlen))
                        return -EFAULT;

                return 0;
        }
        case XDP_MMAP_OFFSETS:
        {
                struct xdp_mmap_offsets off;
                struct xdp_mmap_offsets_v1 off_v1;
                bool flags_supported = true;
                void *to_copy;

                if (len < sizeof(off_v1))
                        return -EINVAL;
                else if (len < sizeof(off))
                        flags_supported = false;

                if (flags_supported) {
                        /* xdp_ring_offset is identical to xdp_ring_offset_v1
                         * except for the flags field added to the end.
                         */
                        xsk_enter_rxtx_offsets((struct xdp_ring_offset_v1 *)
                                               &off.rx);
                        xsk_enter_rxtx_offsets((struct xdp_ring_offset_v1 *)
                                               &off.tx);
                        xsk_enter_umem_offsets((struct xdp_ring_offset_v1 *)
                                               &off.fr);
                        xsk_enter_umem_offsets((struct xdp_ring_offset_v1 *)
                                               &off.cr);
                        off.rx.flags = offsetof(struct xdp_rxtx_ring,
                                                ptrs.flags);
                        off.tx.flags = offsetof(struct xdp_rxtx_ring,
                                                ptrs.flags);
                        off.fr.flags = offsetof(struct xdp_umem_ring,
                                                ptrs.flags);
                        off.cr.flags = offsetof(struct xdp_umem_ring,
                                                ptrs.flags);

                        len = sizeof(off);
                        to_copy = &off;
                } else {
                        xsk_enter_rxtx_offsets(&off_v1.rx);
                        xsk_enter_rxtx_offsets(&off_v1.tx);
                        xsk_enter_umem_offsets(&off_v1.fr);
                        xsk_enter_umem_offsets(&off_v1.cr);

                        len = sizeof(off_v1);
                        to_copy = &off_v1;
                }

                if (copy_to_user(optval, to_copy, len))
                        return -EFAULT;
                if (put_user(len, optlen))
                        return -EFAULT;

                return 0;
        }
        case XDP_OPTIONS:
        {
                struct xdp_options opts = {};

                if (len < sizeof(opts))
                        return -EINVAL;

                mutex_lock(&xs->mutex);
                if (xs->zc)
                        opts.flags |= XDP_OPTIONS_ZEROCOPY;
                mutex_unlock(&xs->mutex);

                len = sizeof(opts);
                if (copy_to_user(optval, &opts, len))
                        return -EFAULT;
                if (put_user(len, optlen))
                        return -EFAULT;

                return 0;
        }
        default:
                break;
        }

        return -EOPNOTSUPP;
}

static int xsk_mmap(struct file *file, struct socket *sock,
                    struct vm_area_struct *vma)
{
        loff_t offset = (loff_t)vma->vm_pgoff << PAGE_SHIFT;
        unsigned long size = vma->vm_end - vma->vm_start;
        struct xdp_sock *xs = xdp_sk(sock->sk);
        struct xsk_queue *q = NULL;
        unsigned long pfn;
        struct page *qpg;

        if (READ_ONCE(xs->state) != XSK_READY)
                return -EBUSY;

        if (offset == XDP_PGOFF_RX_RING) {
                q = READ_ONCE(xs->rx);
        } else if (offset == XDP_PGOFF_TX_RING) {
                q = READ_ONCE(xs->tx);
        } else {
                /* Matches the smp_wmb() in XDP_UMEM_REG */
                smp_rmb();
                if (offset == XDP_UMEM_PGOFF_FILL_RING)
                        q = READ_ONCE(xs->fq_tmp);
                else if (offset == XDP_UMEM_PGOFF_COMPLETION_RING)
                        q = READ_ONCE(xs->cq_tmp);
        }

        if (!q)
                return -EINVAL;

        /* Matches the smp_wmb() in xsk_init_queue */
        smp_rmb();
        qpg = virt_to_head_page(q->ring);
        if (size > page_size(qpg))
                return -EINVAL;

        pfn = virt_to_phys(q->ring) >> PAGE_SHIFT;
        return remap_pfn_range(vma, vma->vm_start, pfn,
                               size, vma->vm_page_prot);
}

static int xsk_notifier(struct notifier_block *this,
                        unsigned long msg, void *ptr)
{
        struct net_device *dev = netdev_notifier_info_to_dev(ptr);
        struct net *net = dev_net(dev);
        struct sock *sk;

        switch (msg) {
        case NETDEV_UNREGISTER:
                mutex_lock(&net->xdp.lock);
                sk_for_each(sk, &net->xdp.list) {
                        struct xdp_sock *xs = xdp_sk(sk);

                        mutex_lock(&xs->mutex);
                        if (xs->dev == dev) {
                                sk->sk_err = ENETDOWN;
                                if (!sock_flag(sk, SOCK_DEAD))
                                        sk->sk_error_report(sk);

                                xsk_unbind_dev(xs);

                                /* Clear device references. */
                                xp_clear_dev(xs->pool);
                        }
                        mutex_unlock(&xs->mutex);
                }
                mutex_unlock(&net->xdp.lock);
                break;
        }
        return NOTIFY_DONE;
}

static struct proto xsk_proto = {
        .name =         "XDP",
        .owner =        THIS_MODULE,
        .obj_size =     sizeof(struct xdp_sock),
};

static const struct proto_ops xsk_proto_ops = {
        .family         = PF_XDP,
        .owner          = THIS_MODULE,
        .release        = xsk_release,
        .bind           = xsk_bind,
        .connect        = sock_no_connect,
        .socketpair     = sock_no_socketpair,
        .accept         = sock_no_accept,
        .getname        = sock_no_getname,
        .poll           = xsk_poll,
        .ioctl          = sock_no_ioctl,
        .listen         = sock_no_listen,
        .shutdown       = sock_no_shutdown,
        .setsockopt     = xsk_setsockopt,
        .getsockopt     = xsk_getsockopt,
        .sendmsg        = xsk_sendmsg,
        .recvmsg        = sock_no_recvmsg,
        .mmap           = xsk_mmap,
        .sendpage       = sock_no_sendpage,
};

static void xsk_destruct(struct sock *sk)
{
        struct xdp_sock *xs = xdp_sk(sk);

        if (!sock_flag(sk, SOCK_DEAD))
                return;

        if (!xp_put_pool(xs->pool))
                xdp_put_umem(xs->umem);

        sk_refcnt_debug_dec(sk);
}

static int xsk_create(struct net *net, struct socket *sock, int protocol,
                      int kern)
{
        struct xdp_sock *xs;
        struct sock *sk;

        if (!ns_capable(net->user_ns, CAP_NET_RAW))
                return -EPERM;
        if (sock->type != SOCK_RAW)
                return -ESOCKTNOSUPPORT;

        if (protocol)
                return -EPROTONOSUPPORT;

        sock->state = SS_UNCONNECTED;

        sk = sk_alloc(net, PF_XDP, GFP_KERNEL, &xsk_proto, kern);
        if (!sk)
                return -ENOBUFS;

        sock->ops = &xsk_proto_ops;

        sock_init_data(sock, sk);

        sk->sk_family = PF_XDP;

        sk->sk_destruct = xsk_destruct;
        sk_refcnt_debug_inc(sk);

        sock_set_flag(sk, SOCK_RCU_FREE);

        xs = xdp_sk(sk);
        xs->state = XSK_READY;
        mutex_init(&xs->mutex);
        spin_lock_init(&xs->rx_lock);
        spin_lock_init(&xs->tx_completion_lock);

        INIT_LIST_HEAD(&xs->map_list);
        spin_lock_init(&xs->map_list_lock);

        mutex_lock(&net->xdp.lock);
        sk_add_node_rcu(sk, &net->xdp.list);
        mutex_unlock(&net->xdp.lock);

        local_bh_disable();
        sock_prot_inuse_add(net, &xsk_proto, 1);
        local_bh_enable();

        return 0;
}

static const struct net_proto_family xsk_family_ops = {
        .family = PF_XDP,
        .create = xsk_create,
        .owner  = THIS_MODULE,
};

static struct notifier_block xsk_netdev_notifier = {
        .notifier_call  = xsk_notifier,
};

static int __net_init xsk_net_init(struct net *net)
{
        mutex_init(&net->xdp.lock);
        INIT_HLIST_HEAD(&net->xdp.list);
        return 0;
}

static void __net_exit xsk_net_exit(struct net *net)
{
        WARN_ON_ONCE(!hlist_empty(&net->xdp.list));
}

static struct pernet_operations xsk_net_ops = {
        .init = xsk_net_init,
        .exit = xsk_net_exit,
};

static int __init xsk_init(void)
{
        int err, cpu;

        err = proto_register(&xsk_proto, 0 /* no slab */);
        if (err)
                goto out;

        err = sock_register(&xsk_family_ops);
        if (err)
                goto out_proto;

        err = register_pernet_subsys(&xsk_net_ops);
        if (err)
                goto out_sk;

        err = register_netdevice_notifier(&xsk_netdev_notifier);
        if (err)
                goto out_pernet;

        for_each_possible_cpu(cpu)
                INIT_LIST_HEAD(&per_cpu(xskmap_flush_list, cpu));
        return 0;

out_pernet:
        unregister_pernet_subsys(&xsk_net_ops);
out_sk:
        sock_unregister(PF_XDP);
out_proto:
        proto_unregister(&xsk_proto);
out:
        return err;
}

fs_initcall(xsk_init);