perf annotate: Show full source location with 'l' hotkey
[sfrench/cifs-2.6.git] / drivers / net / wireguard / send.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5
6 #include "queueing.h"
7 #include "timers.h"
8 #include "device.h"
9 #include "peer.h"
10 #include "socket.h"
11 #include "messages.h"
12 #include "cookie.h"
13
14 #include <linux/uio.h>
15 #include <linux/inetdevice.h>
16 #include <linux/socket.h>
17 #include <net/ip_tunnels.h>
18 #include <net/udp.h>
19 #include <net/sock.h>
20
21 static void wg_packet_send_handshake_initiation(struct wg_peer *peer)
22 {
23         struct message_handshake_initiation packet;
24
25         if (!wg_birthdate_has_expired(atomic64_read(&peer->last_sent_handshake),
26                                       REKEY_TIMEOUT))
27                 return; /* This function is rate limited. */
28
29         atomic64_set(&peer->last_sent_handshake, ktime_get_coarse_boottime_ns());
30         net_dbg_ratelimited("%s: Sending handshake initiation to peer %llu (%pISpfsc)\n",
31                             peer->device->dev->name, peer->internal_id,
32                             &peer->endpoint.addr);
33
34         if (wg_noise_handshake_create_initiation(&packet, &peer->handshake)) {
35                 wg_cookie_add_mac_to_packet(&packet, sizeof(packet), peer);
36                 wg_timers_any_authenticated_packet_traversal(peer);
37                 wg_timers_any_authenticated_packet_sent(peer);
38                 atomic64_set(&peer->last_sent_handshake,
39                              ktime_get_coarse_boottime_ns());
40                 wg_socket_send_buffer_to_peer(peer, &packet, sizeof(packet),
41                                               HANDSHAKE_DSCP);
42                 wg_timers_handshake_initiated(peer);
43         }
44 }
45
46 void wg_packet_handshake_send_worker(struct work_struct *work)
47 {
48         struct wg_peer *peer = container_of(work, struct wg_peer,
49                                             transmit_handshake_work);
50
51         wg_packet_send_handshake_initiation(peer);
52         wg_peer_put(peer);
53 }
54
55 void wg_packet_send_queued_handshake_initiation(struct wg_peer *peer,
56                                                 bool is_retry)
57 {
58         if (!is_retry)
59                 peer->timer_handshake_attempts = 0;
60
61         rcu_read_lock_bh();
62         /* We check last_sent_handshake here in addition to the actual function
63          * we're queueing up, so that we don't queue things if not strictly
64          * necessary:
65          */
66         if (!wg_birthdate_has_expired(atomic64_read(&peer->last_sent_handshake),
67                                       REKEY_TIMEOUT) ||
68                         unlikely(READ_ONCE(peer->is_dead)))
69                 goto out;
70
71         wg_peer_get(peer);
72         /* Queues up calling packet_send_queued_handshakes(peer), where we do a
73          * peer_put(peer) after:
74          */
75         if (!queue_work(peer->device->handshake_send_wq,
76                         &peer->transmit_handshake_work))
77                 /* If the work was already queued, we want to drop the
78                  * extra reference:
79                  */
80                 wg_peer_put(peer);
81 out:
82         rcu_read_unlock_bh();
83 }
84
85 void wg_packet_send_handshake_response(struct wg_peer *peer)
86 {
87         struct message_handshake_response packet;
88
89         atomic64_set(&peer->last_sent_handshake, ktime_get_coarse_boottime_ns());
90         net_dbg_ratelimited("%s: Sending handshake response to peer %llu (%pISpfsc)\n",
91                             peer->device->dev->name, peer->internal_id,
92                             &peer->endpoint.addr);
93
94         if (wg_noise_handshake_create_response(&packet, &peer->handshake)) {
95                 wg_cookie_add_mac_to_packet(&packet, sizeof(packet), peer);
96                 if (wg_noise_handshake_begin_session(&peer->handshake,
97                                                      &peer->keypairs)) {
98                         wg_timers_session_derived(peer);
99                         wg_timers_any_authenticated_packet_traversal(peer);
100                         wg_timers_any_authenticated_packet_sent(peer);
101                         atomic64_set(&peer->last_sent_handshake,
102                                      ktime_get_coarse_boottime_ns());
103                         wg_socket_send_buffer_to_peer(peer, &packet,
104                                                       sizeof(packet),
105                                                       HANDSHAKE_DSCP);
106                 }
107         }
108 }
109
110 void wg_packet_send_handshake_cookie(struct wg_device *wg,
111                                      struct sk_buff *initiating_skb,
112                                      __le32 sender_index)
113 {
114         struct message_handshake_cookie packet;
115
116         net_dbg_skb_ratelimited("%s: Sending cookie response for denied handshake message for %pISpfsc\n",
117                                 wg->dev->name, initiating_skb);
118         wg_cookie_message_create(&packet, initiating_skb, sender_index,
119                                  &wg->cookie_checker);
120         wg_socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet,
121                                               sizeof(packet));
122 }
123
124 static void keep_key_fresh(struct wg_peer *peer)
125 {
126         struct noise_keypair *keypair;
127         bool send;
128
129         rcu_read_lock_bh();
130         keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
131         send = keypair && READ_ONCE(keypair->sending.is_valid) &&
132                (atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES ||
133                 (keypair->i_am_the_initiator &&
134                  wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME)));
135         rcu_read_unlock_bh();
136
137         if (unlikely(send))
138                 wg_packet_send_queued_handshake_initiation(peer, false);
139 }
140
141 static unsigned int calculate_skb_padding(struct sk_buff *skb)
142 {
143         unsigned int padded_size, last_unit = skb->len;
144
145         if (unlikely(!PACKET_CB(skb)->mtu))
146                 return ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE) - last_unit;
147
148         /* We do this modulo business with the MTU, just in case the networking
149          * layer gives us a packet that's bigger than the MTU. In that case, we
150          * wouldn't want the final subtraction to overflow in the case of the
151          * padded_size being clamped. Fortunately, that's very rarely the case,
152          * so we optimize for that not happening.
153          */
154         if (unlikely(last_unit > PACKET_CB(skb)->mtu))
155                 last_unit %= PACKET_CB(skb)->mtu;
156
157         padded_size = min(PACKET_CB(skb)->mtu,
158                           ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE));
159         return padded_size - last_unit;
160 }
161
162 static bool encrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
163 {
164         unsigned int padding_len, plaintext_len, trailer_len;
165         struct scatterlist sg[MAX_SKB_FRAGS + 8];
166         struct message_data *header;
167         struct sk_buff *trailer;
168         int num_frags;
169
170         /* Force hash calculation before encryption so that flow analysis is
171          * consistent over the inner packet.
172          */
173         skb_get_hash(skb);
174
175         /* Calculate lengths. */
176         padding_len = calculate_skb_padding(skb);
177         trailer_len = padding_len + noise_encrypted_len(0);
178         plaintext_len = skb->len + padding_len;
179
180         /* Expand data section to have room for padding and auth tag. */
181         num_frags = skb_cow_data(skb, trailer_len, &trailer);
182         if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
183                 return false;
184
185         /* Set the padding to zeros, and make sure it and the auth tag are part
186          * of the skb.
187          */
188         memset(skb_tail_pointer(trailer), 0, padding_len);
189
190         /* Expand head section to have room for our header and the network
191          * stack's headers.
192          */
193         if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0))
194                 return false;
195
196         /* Finalize checksum calculation for the inner packet, if required. */
197         if (unlikely(skb->ip_summed == CHECKSUM_PARTIAL &&
198                      skb_checksum_help(skb)))
199                 return false;
200
201         /* Only after checksumming can we safely add on the padding at the end
202          * and the header.
203          */
204         skb_set_inner_network_header(skb, 0);
205         header = (struct message_data *)skb_push(skb, sizeof(*header));
206         header->header.type = cpu_to_le32(MESSAGE_DATA);
207         header->key_idx = keypair->remote_index;
208         header->counter = cpu_to_le64(PACKET_CB(skb)->nonce);
209         pskb_put(skb, trailer, trailer_len);
210
211         /* Now we can encrypt the scattergather segments */
212         sg_init_table(sg, num_frags);
213         if (skb_to_sgvec(skb, sg, sizeof(struct message_data),
214                          noise_encrypted_len(plaintext_len)) <= 0)
215                 return false;
216         return chacha20poly1305_encrypt_sg_inplace(sg, plaintext_len, NULL, 0,
217                                                    PACKET_CB(skb)->nonce,
218                                                    keypair->sending.key);
219 }
220
221 void wg_packet_send_keepalive(struct wg_peer *peer)
222 {
223         struct sk_buff *skb;
224
225         if (skb_queue_empty(&peer->staged_packet_queue)) {
226                 skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH,
227                                 GFP_ATOMIC);
228                 if (unlikely(!skb))
229                         return;
230                 skb_reserve(skb, DATA_PACKET_HEAD_ROOM);
231                 skb->dev = peer->device->dev;
232                 PACKET_CB(skb)->mtu = skb->dev->mtu;
233                 skb_queue_tail(&peer->staged_packet_queue, skb);
234                 net_dbg_ratelimited("%s: Sending keepalive packet to peer %llu (%pISpfsc)\n",
235                                     peer->device->dev->name, peer->internal_id,
236                                     &peer->endpoint.addr);
237         }
238
239         wg_packet_send_staged_packets(peer);
240 }
241
242 static void wg_packet_create_data_done(struct sk_buff *first,
243                                        struct wg_peer *peer)
244 {
245         struct sk_buff *skb, *next;
246         bool is_keepalive, data_sent = false;
247
248         wg_timers_any_authenticated_packet_traversal(peer);
249         wg_timers_any_authenticated_packet_sent(peer);
250         skb_list_walk_safe(first, skb, next) {
251                 is_keepalive = skb->len == message_data_len(0);
252                 if (likely(!wg_socket_send_skb_to_peer(peer, skb,
253                                 PACKET_CB(skb)->ds) && !is_keepalive))
254                         data_sent = true;
255         }
256
257         if (likely(data_sent))
258                 wg_timers_data_sent(peer);
259
260         keep_key_fresh(peer);
261 }
262
263 void wg_packet_tx_worker(struct work_struct *work)
264 {
265         struct crypt_queue *queue = container_of(work, struct crypt_queue,
266                                                  work);
267         struct noise_keypair *keypair;
268         enum packet_state state;
269         struct sk_buff *first;
270         struct wg_peer *peer;
271
272         while ((first = __ptr_ring_peek(&queue->ring)) != NULL &&
273                (state = atomic_read_acquire(&PACKET_CB(first)->state)) !=
274                        PACKET_STATE_UNCRYPTED) {
275                 __ptr_ring_discard_one(&queue->ring);
276                 peer = PACKET_PEER(first);
277                 keypair = PACKET_CB(first)->keypair;
278
279                 if (likely(state == PACKET_STATE_CRYPTED))
280                         wg_packet_create_data_done(first, peer);
281                 else
282                         kfree_skb_list(first);
283
284                 wg_noise_keypair_put(keypair, false);
285                 wg_peer_put(peer);
286                 if (need_resched())
287                         cond_resched();
288         }
289 }
290
291 void wg_packet_encrypt_worker(struct work_struct *work)
292 {
293         struct crypt_queue *queue = container_of(work, struct multicore_worker,
294                                                  work)->ptr;
295         struct sk_buff *first, *skb, *next;
296
297         while ((first = ptr_ring_consume_bh(&queue->ring)) != NULL) {
298                 enum packet_state state = PACKET_STATE_CRYPTED;
299
300                 skb_list_walk_safe(first, skb, next) {
301                         if (likely(encrypt_packet(skb,
302                                         PACKET_CB(first)->keypair))) {
303                                 wg_reset_packet(skb, true);
304                         } else {
305                                 state = PACKET_STATE_DEAD;
306                                 break;
307                         }
308                 }
309                 wg_queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first,
310                                           state);
311                 if (need_resched())
312                         cond_resched();
313         }
314 }
315
316 static void wg_packet_create_data(struct sk_buff *first)
317 {
318         struct wg_peer *peer = PACKET_PEER(first);
319         struct wg_device *wg = peer->device;
320         int ret = -EINVAL;
321
322         rcu_read_lock_bh();
323         if (unlikely(READ_ONCE(peer->is_dead)))
324                 goto err;
325
326         ret = wg_queue_enqueue_per_device_and_peer(&wg->encrypt_queue,
327                                                    &peer->tx_queue, first,
328                                                    wg->packet_crypt_wq,
329                                                    &wg->encrypt_queue.last_cpu);
330         if (unlikely(ret == -EPIPE))
331                 wg_queue_enqueue_per_peer(&peer->tx_queue, first,
332                                           PACKET_STATE_DEAD);
333 err:
334         rcu_read_unlock_bh();
335         if (likely(!ret || ret == -EPIPE))
336                 return;
337         wg_noise_keypair_put(PACKET_CB(first)->keypair, false);
338         wg_peer_put(peer);
339         kfree_skb_list(first);
340 }
341
342 void wg_packet_purge_staged_packets(struct wg_peer *peer)
343 {
344         spin_lock_bh(&peer->staged_packet_queue.lock);
345         peer->device->dev->stats.tx_dropped += peer->staged_packet_queue.qlen;
346         __skb_queue_purge(&peer->staged_packet_queue);
347         spin_unlock_bh(&peer->staged_packet_queue.lock);
348 }
349
350 void wg_packet_send_staged_packets(struct wg_peer *peer)
351 {
352         struct noise_keypair *keypair;
353         struct sk_buff_head packets;
354         struct sk_buff *skb;
355
356         /* Steal the current queue into our local one. */
357         __skb_queue_head_init(&packets);
358         spin_lock_bh(&peer->staged_packet_queue.lock);
359         skb_queue_splice_init(&peer->staged_packet_queue, &packets);
360         spin_unlock_bh(&peer->staged_packet_queue.lock);
361         if (unlikely(skb_queue_empty(&packets)))
362                 return;
363
364         /* First we make sure we have a valid reference to a valid key. */
365         rcu_read_lock_bh();
366         keypair = wg_noise_keypair_get(
367                 rcu_dereference_bh(peer->keypairs.current_keypair));
368         rcu_read_unlock_bh();
369         if (unlikely(!keypair))
370                 goto out_nokey;
371         if (unlikely(!READ_ONCE(keypair->sending.is_valid)))
372                 goto out_nokey;
373         if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
374                                               REJECT_AFTER_TIME)))
375                 goto out_invalid;
376
377         /* After we know we have a somewhat valid key, we now try to assign
378          * nonces to all of the packets in the queue. If we can't assign nonces
379          * for all of them, we just consider it a failure and wait for the next
380          * handshake.
381          */
382         skb_queue_walk(&packets, skb) {
383                 /* 0 for no outer TOS: no leak. TODO: at some later point, we
384                  * might consider using flowi->tos as outer instead.
385                  */
386                 PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
387                 PACKET_CB(skb)->nonce =
388                                 atomic64_inc_return(&keypair->sending_counter) - 1;
389                 if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
390                         goto out_invalid;
391         }
392
393         packets.prev->next = NULL;
394         wg_peer_get(keypair->entry.peer);
395         PACKET_CB(packets.next)->keypair = keypair;
396         wg_packet_create_data(packets.next);
397         return;
398
399 out_invalid:
400         WRITE_ONCE(keypair->sending.is_valid, false);
401 out_nokey:
402         wg_noise_keypair_put(keypair, false);
403
404         /* We orphan the packets if we're waiting on a handshake, so that they
405          * don't block a socket's pool.
406          */
407         skb_queue_walk(&packets, skb)
408                 skb_orphan(skb);
409         /* Then we put them back on the top of the queue. We're not too
410          * concerned about accidentally getting things a little out of order if
411          * packets are being added really fast, because this queue is for before
412          * packets can even be sent and it's small anyway.
413          */
414         spin_lock_bh(&peer->staged_packet_queue.lock);
415         skb_queue_splice(&packets, &peer->staged_packet_queue);
416         spin_unlock_bh(&peer->staged_packet_queue.lock);
417
418         /* If we're exiting because there's something wrong with the key, it
419          * means we should initiate a new handshake.
420          */
421         wg_packet_send_queued_handshake_initiation(peer, false);
422 }