Merge branches 'arm/omap', 'arm/exynos', 'arm/smmu', 'arm/mediatek', 'arm/qcom',...
[sfrench/cifs-2.6.git] / net / rxrpc / input.c
index dd47d465d1d3e7de3bcf72dc76611621bb2c0e6a..d122c53c869734ee00ad628fd53c3ae49d143849 100644 (file)
@@ -233,7 +233,7 @@ static bool rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to,
                ix = call->tx_hard_ack & RXRPC_RXTX_BUFF_MASK;
                skb = call->rxtx_buffer[ix];
                annotation = call->rxtx_annotations[ix];
-               rxrpc_see_skb(skb, rxrpc_skb_tx_rotated);
+               rxrpc_see_skb(skb, rxrpc_skb_rotated);
                call->rxtx_buffer[ix] = NULL;
                call->rxtx_annotations[ix] = 0;
                skb->next = list;
@@ -258,7 +258,7 @@ static bool rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to,
                skb = list;
                list = skb->next;
                skb_mark_not_on_list(skb);
-               rxrpc_free_skb(skb, rxrpc_skb_tx_freed);
+               rxrpc_free_skb(skb, rxrpc_skb_freed);
        }
 
        return rot_last;
@@ -347,7 +347,7 @@ static bool rxrpc_receiving_reply(struct rxrpc_call *call)
 }
 
 /*
- * Scan a jumbo packet to validate its structure and to work out how many
+ * Scan a data packet to validate its structure and to work out how many
  * subpackets it contains.
  *
  * A jumbo packet is a collection of consecutive packets glued together with
@@ -358,16 +358,21 @@ static bool rxrpc_receiving_reply(struct rxrpc_call *call)
  * the last are RXRPC_JUMBO_DATALEN in size.  The last subpacket may be of any
  * size.
  */
-static bool rxrpc_validate_jumbo(struct sk_buff *skb)
+static bool rxrpc_validate_data(struct sk_buff *skb)
 {
        struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
        unsigned int offset = sizeof(struct rxrpc_wire_header);
        unsigned int len = skb->len;
-       int nr_jumbo = 1;
        u8 flags = sp->hdr.flags;
 
-       do {
-               nr_jumbo++;
+       for (;;) {
+               if (flags & RXRPC_REQUEST_ACK)
+                       __set_bit(sp->nr_subpackets, sp->rx_req_ack);
+               sp->nr_subpackets++;
+
+               if (!(flags & RXRPC_JUMBO_PACKET))
+                       break;
+
                if (len - offset < RXRPC_JUMBO_SUBPKTLEN)
                        goto protocol_error;
                if (flags & RXRPC_LAST_PACKET)
@@ -376,9 +381,10 @@ static bool rxrpc_validate_jumbo(struct sk_buff *skb)
                if (skb_copy_bits(skb, offset, &flags, 1) < 0)
                        goto protocol_error;
                offset += sizeof(struct rxrpc_jumbo_header);
-       } while (flags & RXRPC_JUMBO_PACKET);
+       }
 
-       sp->nr_jumbo = nr_jumbo;
+       if (flags & RXRPC_LAST_PACKET)
+               sp->rx_flags |= RXRPC_SKB_INCL_LAST;
        return true;
 
 protocol_error:
@@ -399,10 +405,10 @@ protocol_error:
  * (that information is encoded in the ACK packet).
  */
 static void rxrpc_input_dup_data(struct rxrpc_call *call, rxrpc_seq_t seq,
-                                u8 annotation, bool *_jumbo_bad)
+                                bool is_jumbo, bool *_jumbo_bad)
 {
        /* Discard normal packets that are duplicates. */
-       if (annotation == 0)
+       if (is_jumbo)
                return;
 
        /* Skip jumbo subpackets that are duplicates.  When we've had three or
@@ -416,29 +422,30 @@ static void rxrpc_input_dup_data(struct rxrpc_call *call, rxrpc_seq_t seq,
 }
 
 /*
- * Process a DATA packet, adding the packet to the Rx ring.
+ * Process a DATA packet, adding the packet to the Rx ring.  The caller's
+ * packet ref must be passed on or discarded.
  */
 static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
 {
        struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
        enum rxrpc_call_state state;
-       unsigned int offset = sizeof(struct rxrpc_wire_header);
-       unsigned int ix;
+       unsigned int j;
        rxrpc_serial_t serial = sp->hdr.serial, ack_serial = 0;
-       rxrpc_seq_t seq = sp->hdr.seq, hard_ack;
-       bool immediate_ack = false, jumbo_bad = false, queued;
-       u16 len;
-       u8 ack = 0, flags, annotation = 0;
+       rxrpc_seq_t seq0 = sp->hdr.seq, hard_ack;
+       bool immediate_ack = false, jumbo_bad = false;
+       u8 ack = 0;
 
        _enter("{%u,%u},{%u,%u}",
-              call->rx_hard_ack, call->rx_top, skb->len, seq);
+              call->rx_hard_ack, call->rx_top, skb->len, seq0);
 
-       _proto("Rx DATA %%%u { #%u f=%02x }",
-              sp->hdr.serial, seq, sp->hdr.flags);
+       _proto("Rx DATA %%%u { #%u f=%02x n=%u }",
+              sp->hdr.serial, seq0, sp->hdr.flags, sp->nr_subpackets);
 
        state = READ_ONCE(call->state);
-       if (state >= RXRPC_CALL_COMPLETE)
+       if (state >= RXRPC_CALL_COMPLETE) {
+               rxrpc_free_skb(skb, rxrpc_skb_freed);
                return;
+       }
 
        if (call->state == RXRPC_CALL_SERVER_RECV_REQUEST) {
                unsigned long timo = READ_ONCE(call->next_req_timo);
@@ -463,137 +470,137 @@ static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb)
            !rxrpc_receiving_reply(call))
                goto unlock;
 
-       call->ackr_prev_seq = seq;
-
+       call->ackr_prev_seq = seq0;
        hard_ack = READ_ONCE(call->rx_hard_ack);
-       if (after(seq, hard_ack + call->rx_winsize)) {
-               ack = RXRPC_ACK_EXCEEDS_WINDOW;
-               ack_serial = serial;
-               goto ack;
-       }
 
-       flags = sp->hdr.flags;
-       if (flags & RXRPC_JUMBO_PACKET) {
+       if (sp->nr_subpackets > 1) {
                if (call->nr_jumbo_bad > 3) {
                        ack = RXRPC_ACK_NOSPACE;
                        ack_serial = serial;
                        goto ack;
                }
-               annotation = 1;
        }
 
-next_subpacket:
-       queued = false;
-       ix = seq & RXRPC_RXTX_BUFF_MASK;
-       len = skb->len;
-       if (flags & RXRPC_JUMBO_PACKET)
-               len = RXRPC_JUMBO_DATALEN;
-
-       if (flags & RXRPC_LAST_PACKET) {
-               if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) &&
-                   seq != call->rx_top) {
-                       rxrpc_proto_abort("LSN", call, seq);
-                       goto unlock;
-               }
-       } else {
-               if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) &&
-                   after_eq(seq, call->rx_top)) {
-                       rxrpc_proto_abort("LSA", call, seq);
-                       goto unlock;
+       for (j = 0; j < sp->nr_subpackets; j++) {
+               rxrpc_serial_t serial = sp->hdr.serial + j;
+               rxrpc_seq_t seq = seq0 + j;
+               unsigned int ix = seq & RXRPC_RXTX_BUFF_MASK;
+               bool terminal = (j == sp->nr_subpackets - 1);
+               bool last = terminal && (sp->rx_flags & RXRPC_SKB_INCL_LAST);
+               u8 flags, annotation = j;
+
+               _proto("Rx DATA+%u %%%u { #%x t=%u l=%u }",
+                    j, serial, seq, terminal, last);
+
+               if (last) {
+                       if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) &&
+                           seq != call->rx_top) {
+                               rxrpc_proto_abort("LSN", call, seq);
+                               goto unlock;
+                       }
+               } else {
+                       if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) &&
+                           after_eq(seq, call->rx_top)) {
+                               rxrpc_proto_abort("LSA", call, seq);
+                               goto unlock;
+                       }
                }
-       }
 
-       trace_rxrpc_rx_data(call->debug_id, seq, serial, flags, annotation);
-       if (before_eq(seq, hard_ack)) {
-               ack = RXRPC_ACK_DUPLICATE;
-               ack_serial = serial;
-               goto skip;
-       }
+               flags = 0;
+               if (last)
+                       flags |= RXRPC_LAST_PACKET;
+               if (!terminal)
+                       flags |= RXRPC_JUMBO_PACKET;
+               if (test_bit(j, sp->rx_req_ack))
+                       flags |= RXRPC_REQUEST_ACK;
+               trace_rxrpc_rx_data(call->debug_id, seq, serial, flags, annotation);
 
-       if (flags & RXRPC_REQUEST_ACK && !ack) {
-               ack = RXRPC_ACK_REQUESTED;
-               ack_serial = serial;
-       }
-
-       if (call->rxtx_buffer[ix]) {
-               rxrpc_input_dup_data(call, seq, annotation, &jumbo_bad);
-               if (ack != RXRPC_ACK_DUPLICATE) {
+               if (before_eq(seq, hard_ack)) {
                        ack = RXRPC_ACK_DUPLICATE;
                        ack_serial = serial;
+                       continue;
                }
-               immediate_ack = true;
-               goto skip;
-       }
 
-       /* Queue the packet.  We use a couple of memory barriers here as need
-        * to make sure that rx_top is perceived to be set after the buffer
-        * pointer and that the buffer pointer is set after the annotation and
-        * the skb data.
-        *
-        * Barriers against rxrpc_recvmsg_data() and rxrpc_rotate_rx_window()
-        * and also rxrpc_fill_out_ack().
-        */
-       rxrpc_get_skb(skb, rxrpc_skb_rx_got);
-       call->rxtx_annotations[ix] = annotation;
-       smp_wmb();
-       call->rxtx_buffer[ix] = skb;
-       if (after(seq, call->rx_top)) {
-               smp_store_release(&call->rx_top, seq);
-       } else if (before(seq, call->rx_top)) {
-               /* Send an immediate ACK if we fill in a hole */
-               if (!ack) {
-                       ack = RXRPC_ACK_DELAY;
-                       ack_serial = serial;
+               if (call->rxtx_buffer[ix]) {
+                       rxrpc_input_dup_data(call, seq, sp->nr_subpackets > 1,
+                                            &jumbo_bad);
+                       if (ack != RXRPC_ACK_DUPLICATE) {
+                               ack = RXRPC_ACK_DUPLICATE;
+                               ack_serial = serial;
+                       }
+                       immediate_ack = true;
+                       continue;
                }
-               immediate_ack = true;
-       }
-       if (flags & RXRPC_LAST_PACKET) {
-               set_bit(RXRPC_CALL_RX_LAST, &call->flags);
-               trace_rxrpc_receive(call, rxrpc_receive_queue_last, serial, seq);
-       } else {
-               trace_rxrpc_receive(call, rxrpc_receive_queue, serial, seq);
-       }
-       queued = true;
 
-       if (after_eq(seq, call->rx_expect_next)) {
-               if (after(seq, call->rx_expect_next)) {
-                       _net("OOS %u > %u", seq, call->rx_expect_next);
-                       ack = RXRPC_ACK_OUT_OF_SEQUENCE;
-                       ack_serial = serial;
-               }
-               call->rx_expect_next = seq + 1;
-       }
-
-skip:
-       offset += len;
-       if (flags & RXRPC_JUMBO_PACKET) {
-               if (skb_copy_bits(skb, offset, &flags, 1) < 0) {
-                       rxrpc_proto_abort("XJF", call, seq);
-                       goto unlock;
-               }
-               offset += sizeof(struct rxrpc_jumbo_header);
-               seq++;
-               serial++;
-               annotation++;
-               if (flags & RXRPC_JUMBO_PACKET)
-                       annotation |= RXRPC_RX_ANNO_JLAST;
                if (after(seq, hard_ack + call->rx_winsize)) {
                        ack = RXRPC_ACK_EXCEEDS_WINDOW;
                        ack_serial = serial;
-                       if (!jumbo_bad) {
-                               call->nr_jumbo_bad++;
-                               jumbo_bad = true;
+                       if (flags & RXRPC_JUMBO_PACKET) {
+                               if (!jumbo_bad) {
+                                       call->nr_jumbo_bad++;
+                                       jumbo_bad = true;
+                               }
                        }
+
                        goto ack;
                }
 
-               _proto("Rx DATA Jumbo %%%u", serial);
-               goto next_subpacket;
-       }
+               if (flags & RXRPC_REQUEST_ACK && !ack) {
+                       ack = RXRPC_ACK_REQUESTED;
+                       ack_serial = serial;
+               }
+
+               /* Queue the packet.  We use a couple of memory barriers here as need
+                * to make sure that rx_top is perceived to be set after the buffer
+                * pointer and that the buffer pointer is set after the annotation and
+                * the skb data.
+                *
+                * Barriers against rxrpc_recvmsg_data() and rxrpc_rotate_rx_window()
+                * and also rxrpc_fill_out_ack().
+                */
+               if (!terminal)
+                       rxrpc_get_skb(skb, rxrpc_skb_got);
+               call->rxtx_annotations[ix] = annotation;
+               smp_wmb();
+               call->rxtx_buffer[ix] = skb;
+               if (after(seq, call->rx_top)) {
+                       smp_store_release(&call->rx_top, seq);
+               } else if (before(seq, call->rx_top)) {
+                       /* Send an immediate ACK if we fill in a hole */
+                       if (!ack) {
+                               ack = RXRPC_ACK_DELAY;
+                               ack_serial = serial;
+                       }
+                       immediate_ack = true;
+               }
+
+               if (terminal) {
+                       /* From this point on, we're not allowed to touch the
+                        * packet any longer as its ref now belongs to the Rx
+                        * ring.
+                        */
+                       skb = NULL;
+               }
 
-       if (queued && flags & RXRPC_LAST_PACKET && !ack) {
-               ack = RXRPC_ACK_DELAY;
-               ack_serial = serial;
+               if (last) {
+                       set_bit(RXRPC_CALL_RX_LAST, &call->flags);
+                       if (!ack) {
+                               ack = RXRPC_ACK_DELAY;
+                               ack_serial = serial;
+                       }
+                       trace_rxrpc_receive(call, rxrpc_receive_queue_last, serial, seq);
+               } else {
+                       trace_rxrpc_receive(call, rxrpc_receive_queue, serial, seq);
+               }
+
+               if (after_eq(seq, call->rx_expect_next)) {
+                       if (after(seq, call->rx_expect_next)) {
+                               _net("OOS %u > %u", seq, call->rx_expect_next);
+                               ack = RXRPC_ACK_OUT_OF_SEQUENCE;
+                               ack_serial = serial;
+                       }
+                       call->rx_expect_next = seq + 1;
+               }
        }
 
 ack:
@@ -606,13 +613,14 @@ ack:
                                  false, true,
                                  rxrpc_propose_ack_input_data);
 
-       if (sp->hdr.seq == READ_ONCE(call->rx_hard_ack) + 1) {
+       if (seq0 == READ_ONCE(call->rx_hard_ack) + 1) {
                trace_rxrpc_notify_socket(call->debug_id, serial);
                rxrpc_notify_socket(call);
        }
 
 unlock:
        spin_unlock(&call->input_lock);
+       rxrpc_free_skb(skb, rxrpc_skb_freed);
        _leave(" [queued]");
 }
 
@@ -1021,7 +1029,7 @@ static void rxrpc_input_call_packet(struct rxrpc_call *call,
        switch (sp->hdr.type) {
        case RXRPC_PACKET_TYPE_DATA:
                rxrpc_input_data(call, skb);
-               break;
+               goto no_free;
 
        case RXRPC_PACKET_TYPE_ACK:
                rxrpc_input_ack(call, skb);
@@ -1048,6 +1056,8 @@ static void rxrpc_input_call_packet(struct rxrpc_call *call,
                break;
        }
 
+       rxrpc_free_skb(skb, rxrpc_skb_freed);
+no_free:
        _leave("");
 }
 
@@ -1109,7 +1119,7 @@ static void rxrpc_post_packet_to_local(struct rxrpc_local *local,
                skb_queue_tail(&local->event_queue, skb);
                rxrpc_queue_local(local);
        } else {
-               rxrpc_free_skb(skb, rxrpc_skb_rx_freed);
+               rxrpc_free_skb(skb, rxrpc_skb_freed);
        }
 }
 
@@ -1124,7 +1134,7 @@ static void rxrpc_reject_packet(struct rxrpc_local *local, struct sk_buff *skb)
                skb_queue_tail(&local->reject_queue, skb);
                rxrpc_queue_local(local);
        } else {
-               rxrpc_free_skb(skb, rxrpc_skb_rx_freed);
+               rxrpc_free_skb(skb, rxrpc_skb_freed);
        }
 }
 
@@ -1188,7 +1198,7 @@ int rxrpc_input_packet(struct sock *udp_sk, struct sk_buff *skb)
        if (skb->tstamp == 0)
                skb->tstamp = ktime_get_real();
 
-       rxrpc_new_skb(skb, rxrpc_skb_rx_received);
+       rxrpc_new_skb(skb, rxrpc_skb_received);
 
        skb_pull(skb, sizeof(struct udphdr));
 
@@ -1205,7 +1215,7 @@ int rxrpc_input_packet(struct sock *udp_sk, struct sk_buff *skb)
                static int lose;
                if ((lose++ & 7) == 7) {
                        trace_rxrpc_rx_lose(sp);
-                       rxrpc_free_skb(skb, rxrpc_skb_rx_lost);
+                       rxrpc_free_skb(skb, rxrpc_skb_lost);
                        return 0;
                }
        }
@@ -1237,9 +1247,26 @@ int rxrpc_input_packet(struct sock *udp_sk, struct sk_buff *skb)
                if (sp->hdr.callNumber == 0 ||
                    sp->hdr.seq == 0)
                        goto bad_message;
-               if (sp->hdr.flags & RXRPC_JUMBO_PACKET &&
-                   !rxrpc_validate_jumbo(skb))
+               if (!rxrpc_validate_data(skb))
                        goto bad_message;
+
+               /* Unshare the packet so that it can be modified for in-place
+                * decryption.
+                */
+               if (sp->hdr.securityIndex != 0) {
+                       struct sk_buff *nskb = skb_unshare(skb, GFP_ATOMIC);
+                       if (!nskb) {
+                               rxrpc_eaten_skb(skb, rxrpc_skb_unshared_nomem);
+                               goto out;
+                       }
+
+                       if (nskb != skb) {
+                               rxrpc_eaten_skb(skb, rxrpc_skb_received);
+                               rxrpc_new_skb(skb, rxrpc_skb_unshared);
+                               skb = nskb;
+                               sp = rxrpc_skb(skb);
+                       }
+               }
                break;
 
        case RXRPC_PACKET_TYPE_CHALLENGE:
@@ -1373,11 +1400,14 @@ int rxrpc_input_packet(struct sock *udp_sk, struct sk_buff *skb)
                mutex_unlock(&call->user_mutex);
        }
 
+       /* Process a call packet; this either discards or passes on the ref
+        * elsewhere.
+        */
        rxrpc_input_call_packet(call, skb);
-       goto discard;
+       goto out;
 
 discard:
-       rxrpc_free_skb(skb, rxrpc_skb_rx_freed);
+       rxrpc_free_skb(skb, rxrpc_skb_freed);
 out:
        trace_rxrpc_rx_done(0, 0);
        return 0;