bridge: simplify ip_mc_check_igmp() and ipv6_mc_check_mld() calls
authorLinus Lüssing <linus.luessing@c0d3.blue>
Mon, 21 Jan 2019 06:26:25 +0000 (07:26 +0100)
committerDavid S. Miller <davem@davemloft.net>
Wed, 23 Jan 2019 01:18:08 +0000 (17:18 -0800)
This patch refactors ip_mc_check_igmp(), ipv6_mc_check_mld() and
their callers (more precisely, the Linux bridge) to not rely on
the skb_trimmed parameter anymore.

An skb with its tail trimmed to the IP packet length was initially
introduced for the following three reasons:

1) To be able to verify the ICMPv6 checksum.
2) To be able to distinguish the version of an IGMP or MLD query.
   They are distinguishable only by their size.
3) To avoid parsing data for an IGMPv3 or MLDv2 report that is
   beyond the IP packet but still within the skb.

The first case still uses a cloned and potentially trimmed skb to
verfiy. However, there is no need to propagate it to the caller.
For the second and third case explicit IP packet length checks were
added.

This hopefully makes ip_mc_check_igmp() and ipv6_mc_check_mld() easier
to read and verfiy, as well as easier to use.

Signed-off-by: Linus Lüssing <linus.luessing@c0d3.blue>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/linux/igmp.h
include/linux/ip.h
include/linux/ipv6.h
include/net/addrconf.h
net/batman-adv/multicast.c
net/bridge/br_multicast.c
net/ipv4/igmp.c
net/ipv6/mcast_snoop.c

index 119f53941c124c22452bf615f9ccca5a9130bb87..8b4348f69bc5c119093b9992924a5933ef3142a7 100644 (file)
@@ -18,6 +18,7 @@
 #include <linux/skbuff.h>
 #include <linux/timer.h>
 #include <linux/in.h>
+#include <linux/ip.h>
 #include <linux/refcount.h>
 #include <uapi/linux/igmp.h>
 
@@ -106,6 +107,14 @@ struct ip_mc_list {
 #define IGMPV3_QQIC(value) IGMPV3_EXP(0x80, 4, 3, value)
 #define IGMPV3_MRC(value) IGMPV3_EXP(0x80, 4, 3, value)
 
+static inline int ip_mc_may_pull(struct sk_buff *skb, unsigned int len)
+{
+       if (skb_transport_offset(skb) + ip_transport_len(skb) < len)
+               return -EINVAL;
+
+       return pskb_may_pull(skb, len);
+}
+
 extern int ip_check_mc_rcu(struct in_device *dev, __be32 mc_addr, __be32 src_addr, u8 proto);
 extern int igmp_rcv(struct sk_buff *);
 extern int ip_mc_join_group(struct sock *sk, struct ip_mreqn *imr);
@@ -130,6 +139,6 @@ extern void ip_mc_unmap(struct in_device *);
 extern void ip_mc_remap(struct in_device *);
 extern void ip_mc_dec_group(struct in_device *in_dev, __be32 addr);
 extern void ip_mc_inc_group(struct in_device *in_dev, __be32 addr);
-int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed);
+int ip_mc_check_igmp(struct sk_buff *skb);
 
 #endif
index 492bc6513533cf8d09fdfe667711ab0b78c3a0d2..482b7b7c9f30c3d4b2aa15b07238dc57fa7d5649 100644 (file)
@@ -34,4 +34,9 @@ static inline struct iphdr *ipip_hdr(const struct sk_buff *skb)
 {
        return (struct iphdr *)skb_transport_header(skb);
 }
+
+static inline unsigned int ip_transport_len(const struct sk_buff *skb)
+{
+       return ntohs(ip_hdr(skb)->tot_len) - skb_network_header_len(skb);
+}
 #endif /* _LINUX_IP_H */
index 495e834c1367ddc9e680f56d0cc66f4055fd1f6e..6d45ce784beaaf3ae5a3f29132c3b146221a6896 100644 (file)
@@ -104,6 +104,12 @@ static inline struct ipv6hdr *ipipv6_hdr(const struct sk_buff *skb)
        return (struct ipv6hdr *)skb_transport_header(skb);
 }
 
+static inline unsigned int ipv6_transport_len(const struct sk_buff *skb)
+{
+       return ntohs(ipv6_hdr(skb)->payload_len) + sizeof(struct ipv6hdr) -
+              skb_network_header_len(skb);
+}
+
 /* 
    This structure contains results of exthdrs parsing
    as offsets from skb->nh.
index 1656c59784987bd486ace6be1f10705fb47ac5c6..daf11dcb0f70918837463e5144183281cdb0bd11 100644 (file)
@@ -49,6 +49,7 @@ struct prefix_info {
        struct in6_addr         prefix;
 };
 
+#include <linux/ipv6.h>
 #include <linux/netdevice.h>
 #include <net/if_inet6.h>
 #include <net/ipv6.h>
@@ -201,6 +202,15 @@ u32 ipv6_addr_label(struct net *net, const struct in6_addr *addr,
 /*
  *     multicast prototypes (mcast.c)
  */
+static inline int ipv6_mc_may_pull(struct sk_buff *skb,
+                                  unsigned int len)
+{
+       if (skb_transport_offset(skb) + ipv6_transport_len(skb) < len)
+               return -EINVAL;
+
+       return pskb_may_pull(skb, len);
+}
+
 int ipv6_sock_mc_join(struct sock *sk, int ifindex,
                      const struct in6_addr *addr);
 int ipv6_sock_mc_drop(struct sock *sk, int ifindex,
@@ -219,7 +229,7 @@ void ipv6_mc_unmap(struct inet6_dev *idev);
 void ipv6_mc_remap(struct inet6_dev *idev);
 void ipv6_mc_init_dev(struct inet6_dev *idev);
 void ipv6_mc_destroy_dev(struct inet6_dev *idev);
-int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed);
+int ipv6_mc_check_mld(struct sk_buff *skb);
 void addrconf_dad_failure(struct sk_buff *skb, struct inet6_ifaddr *ifp);
 
 bool ipv6_chk_mcast_addr(struct net_device *dev, const struct in6_addr *group,
index 69244e4598f5a4df7e0d136e6ce3181b02a085ef..1dd70f048e7b5c4f65f7f9d42038b53b31f64ae2 100644 (file)
@@ -674,7 +674,7 @@ static void batadv_mcast_mla_update(struct work_struct *work)
  */
 static bool batadv_mcast_is_report_ipv4(struct sk_buff *skb)
 {
-       if (ip_mc_check_igmp(skb, NULL) < 0)
+       if (ip_mc_check_igmp(skb) < 0)
                return false;
 
        switch (igmp_hdr(skb)->type) {
@@ -741,7 +741,7 @@ static int batadv_mcast_forw_mode_check_ipv4(struct batadv_priv *bat_priv,
  */
 static bool batadv_mcast_is_report_ipv6(struct sk_buff *skb)
 {
-       if (ipv6_mc_check_mld(skb, NULL) < 0)
+       if (ipv6_mc_check_mld(skb) < 0)
                return false;
 
        switch (icmp6_hdr(skb)->icmp6_type) {
index 3aeff0895669609b753607abb362fc7bbbb7f28a..156c4905639edaccc5f6c82c9539cfae089738b2 100644 (file)
@@ -938,7 +938,7 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge *br,
 
        for (i = 0; i < num; i++) {
                len += sizeof(*grec);
-               if (!pskb_may_pull(skb, len))
+               if (!ip_mc_may_pull(skb, len))
                        return -EINVAL;
 
                grec = (void *)(skb->data + len - sizeof(*grec));
@@ -946,7 +946,7 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge *br,
                type = grec->grec_type;
 
                len += ntohs(grec->grec_nsrcs) * 4;
-               if (!pskb_may_pull(skb, len))
+               if (!ip_mc_may_pull(skb, len))
                        return -EINVAL;
 
                /* We treat this as an IGMPv2 report for now. */
@@ -985,15 +985,17 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br,
                                        struct sk_buff *skb,
                                        u16 vid)
 {
+       unsigned int nsrcs_offset;
        const unsigned char *src;
        struct icmp6hdr *icmp6h;
        struct mld2_grec *grec;
+       unsigned int grec_len;
        int i;
        int len;
        int num;
        int err = 0;
 
-       if (!pskb_may_pull(skb, sizeof(*icmp6h)))
+       if (!ipv6_mc_may_pull(skb, sizeof(*icmp6h)))
                return -EINVAL;
 
        icmp6h = icmp6_hdr(skb);
@@ -1003,21 +1005,25 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br,
        for (i = 0; i < num; i++) {
                __be16 *nsrcs, _nsrcs;
 
-               nsrcs = skb_header_pointer(skb,
-                                          len + offsetof(struct mld2_grec,
-                                                         grec_nsrcs),
+               nsrcs_offset = len + offsetof(struct mld2_grec, grec_nsrcs);
+
+               if (skb_transport_offset(skb) + ipv6_transport_len(skb) <
+                   nsrcs_offset + sizeof(_nsrcs))
+                       return -EINVAL;
+
+               nsrcs = skb_header_pointer(skb, nsrcs_offset,
                                           sizeof(_nsrcs), &_nsrcs);
                if (!nsrcs)
                        return -EINVAL;
 
-               if (!pskb_may_pull(skb,
-                                  len + sizeof(*grec) +
-                                  sizeof(struct in6_addr) * ntohs(*nsrcs)))
+               grec_len = sizeof(*grec) +
+                          sizeof(struct in6_addr) * ntohs(*nsrcs);
+
+               if (!ipv6_mc_may_pull(skb, len + grec_len))
                        return -EINVAL;
 
                grec = (struct mld2_grec *)(skb->data + len);
-               len += sizeof(*grec) +
-                      sizeof(struct in6_addr) * ntohs(*nsrcs);
+               len += grec_len;
 
                /* We treat these as MLDv1 reports for now. */
                switch (grec->grec_type) {
@@ -1219,6 +1225,7 @@ static void br_ip4_multicast_query(struct net_bridge *br,
                                   struct sk_buff *skb,
                                   u16 vid)
 {
+       unsigned int transport_len = ip_transport_len(skb);
        const struct iphdr *iph = ip_hdr(skb);
        struct igmphdr *ih = igmp_hdr(skb);
        struct net_bridge_mdb_entry *mp;
@@ -1228,7 +1235,6 @@ static void br_ip4_multicast_query(struct net_bridge *br,
        struct br_ip saddr;
        unsigned long max_delay;
        unsigned long now = jiffies;
-       unsigned int offset = skb_transport_offset(skb);
        __be32 group;
 
        spin_lock(&br->multicast_lock);
@@ -1238,14 +1244,14 @@ static void br_ip4_multicast_query(struct net_bridge *br,
 
        group = ih->group;
 
-       if (skb->len == offset + sizeof(*ih)) {
+       if (transport_len == sizeof(*ih)) {
                max_delay = ih->code * (HZ / IGMP_TIMER_SCALE);
 
                if (!max_delay) {
                        max_delay = 10 * HZ;
                        group = 0;
                }
-       } else if (skb->len >= offset + sizeof(*ih3)) {
+       } else if (transport_len >= sizeof(*ih3)) {
                ih3 = igmpv3_query_hdr(skb);
                if (ih3->nsrcs)
                        goto out;
@@ -1296,6 +1302,7 @@ static int br_ip6_multicast_query(struct net_bridge *br,
                                  struct sk_buff *skb,
                                  u16 vid)
 {
+       unsigned int transport_len = ipv6_transport_len(skb);
        const struct ipv6hdr *ip6h = ipv6_hdr(skb);
        struct mld_msg *mld;
        struct net_bridge_mdb_entry *mp;
@@ -1315,7 +1322,7 @@ static int br_ip6_multicast_query(struct net_bridge *br,
            (port && port->state == BR_STATE_DISABLED))
                goto out;
 
-       if (skb->len == offset + sizeof(*mld)) {
+       if (transport_len == sizeof(*mld)) {
                if (!pskb_may_pull(skb, offset + sizeof(*mld))) {
                        err = -EINVAL;
                        goto out;
@@ -1581,12 +1588,11 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
                                 struct sk_buff *skb,
                                 u16 vid)
 {
-       struct sk_buff *skb_trimmed = NULL;
        const unsigned char *src;
        struct igmphdr *ih;
        int err;
 
-       err = ip_mc_check_igmp(skb, &skb_trimmed);
+       err = ip_mc_check_igmp(skb);
 
        if (err == -ENOMSG) {
                if (!ipv4_is_local_multicast(ip_hdr(skb)->daddr)) {
@@ -1612,19 +1618,16 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
                err = br_ip4_multicast_add_group(br, port, ih->group, vid, src);
                break;
        case IGMPV3_HOST_MEMBERSHIP_REPORT:
-               err = br_ip4_multicast_igmp3_report(br, port, skb_trimmed, vid);
+               err = br_ip4_multicast_igmp3_report(br, port, skb, vid);
                break;
        case IGMP_HOST_MEMBERSHIP_QUERY:
-               br_ip4_multicast_query(br, port, skb_trimmed, vid);
+               br_ip4_multicast_query(br, port, skb, vid);
                break;
        case IGMP_HOST_LEAVE_MESSAGE:
                br_ip4_multicast_leave_group(br, port, ih->group, vid, src);
                break;
        }
 
-       if (skb_trimmed && skb_trimmed != skb)
-               kfree_skb(skb_trimmed);
-
        br_multicast_count(br, port, skb, BR_INPUT_SKB_CB(skb)->igmp,
                           BR_MCAST_DIR_RX);
 
@@ -1637,12 +1640,11 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
                                 struct sk_buff *skb,
                                 u16 vid)
 {
-       struct sk_buff *skb_trimmed = NULL;
        const unsigned char *src;
        struct mld_msg *mld;
        int err;
 
-       err = ipv6_mc_check_mld(skb, &skb_trimmed);
+       err = ipv6_mc_check_mld(skb);
 
        if (err == -ENOMSG) {
                if (!ipv6_addr_is_ll_all_nodes(&ipv6_hdr(skb)->daddr))
@@ -1664,10 +1666,10 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
                                                 src);
                break;
        case ICMPV6_MLD2_REPORT:
-               err = br_ip6_multicast_mld2_report(br, port, skb_trimmed, vid);
+               err = br_ip6_multicast_mld2_report(br, port, skb, vid);
                break;
        case ICMPV6_MGM_QUERY:
-               err = br_ip6_multicast_query(br, port, skb_trimmed, vid);
+               err = br_ip6_multicast_query(br, port, skb, vid);
                break;
        case ICMPV6_MGM_REDUCTION:
                src = eth_hdr(skb)->h_source;
@@ -1675,9 +1677,6 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
                break;
        }
 
-       if (skb_trimmed && skb_trimmed != skb)
-               kfree_skb(skb_trimmed);
-
        br_multicast_count(br, port, skb, BR_INPUT_SKB_CB(skb)->igmp,
                           BR_MCAST_DIR_RX);
 
index 765b2b32c4a4263640563f34b4dd93b5bdf471de..b1f6d93282d7fbe96c6020d64f3f97f608a06399 100644 (file)
@@ -1544,7 +1544,7 @@ static inline __sum16 ip_mc_validate_checksum(struct sk_buff *skb)
        return skb_checksum_simple_validate(skb);
 }
 
-static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
+static int __ip_mc_check_igmp(struct sk_buff *skb)
 
 {
        struct sk_buff *skb_chk;
@@ -1566,16 +1566,10 @@ static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
        if (ret)
                goto err;
 
-       if (skb_trimmed)
-               *skb_trimmed = skb_chk;
-       /* free now unneeded clone */
-       else if (skb_chk != skb)
-               kfree_skb(skb_chk);
-
        ret = 0;
 
 err:
-       if (ret && skb_chk && skb_chk != skb)
+       if (skb_chk && skb_chk != skb)
                kfree_skb(skb_chk);
 
        return ret;
@@ -1584,7 +1578,6 @@ err:
 /**
  * ip_mc_check_igmp - checks whether this is a sane IGMP packet
  * @skb: the skb to validate
- * @skb_trimmed: to store an skb pointer trimmed to IPv4 packet tail (optional)
  *
  * Checks whether an IPv4 packet is a valid IGMP packet. If so sets
  * skb transport header accordingly and returns zero.
@@ -1594,18 +1587,10 @@ err:
  * -ENOMSG: IP header validation succeeded but it is not an IGMP packet.
  * -ENOMEM: A memory allocation failure happened.
  *
- * Optionally, an skb pointer might be provided via skb_trimmed (or set it
- * to NULL): After parsing an IGMP packet successfully it will point to
- * an skb which has its tail aligned to the IP packet end. This might
- * either be the originally provided skb or a trimmed, cloned version if
- * the skb frame had data beyond the IP packet. A cloned skb allows us
- * to leave the original skb and its full frame unchanged (which might be
- * desirable for layer 2 frame jugglers).
- *
  * Caller needs to set the skb network header and free any returned skb if it
  * differs from the provided skb.
  */
-int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
+int ip_mc_check_igmp(struct sk_buff *skb)
 {
        int ret = ip_mc_check_iphdr(skb);
 
@@ -1615,7 +1600,7 @@ int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
        if (ip_hdr(skb)->protocol != IPPROTO_IGMP)
                return -ENOMSG;
 
-       return __ip_mc_check_igmp(skb, skb_trimmed);
+       return __ip_mc_check_igmp(skb);
 }
 EXPORT_SYMBOL(ip_mc_check_igmp);
 
index 9405b04eecc64f478960329da93f6e01d437954e..1a917dc80d5ed51f589b3724d0056e01ec7c7d6a 100644 (file)
@@ -136,8 +136,7 @@ static inline __sum16 ipv6_mc_validate_checksum(struct sk_buff *skb)
        return skb_checksum_validate(skb, IPPROTO_ICMPV6, ip6_compute_pseudo);
 }
 
-static int __ipv6_mc_check_mld(struct sk_buff *skb,
-                              struct sk_buff **skb_trimmed)
+static int __ipv6_mc_check_mld(struct sk_buff *skb)
 
 {
        struct sk_buff *skb_chk = NULL;
@@ -160,16 +159,10 @@ static int __ipv6_mc_check_mld(struct sk_buff *skb,
        if (ret)
                goto err;
 
-       if (skb_trimmed)
-               *skb_trimmed = skb_chk;
-       /* free now unneeded clone */
-       else if (skb_chk != skb)
-               kfree_skb(skb_chk);
-
        ret = 0;
 
 err:
-       if (ret && skb_chk && skb_chk != skb)
+       if (skb_chk && skb_chk != skb)
                kfree_skb(skb_chk);
 
        return ret;
@@ -178,7 +171,6 @@ err:
 /**
  * ipv6_mc_check_mld - checks whether this is a sane MLD packet
  * @skb: the skb to validate
- * @skb_trimmed: to store an skb pointer trimmed to IPv6 packet tail (optional)
  *
  * Checks whether an IPv6 packet is a valid MLD packet. If so sets
  * skb transport header accordingly and returns zero.
@@ -188,18 +180,10 @@ err:
  * -ENOMSG: IP header validation succeeded but it is not an MLD packet.
  * -ENOMEM: A memory allocation failure happened.
  *
- * Optionally, an skb pointer might be provided via skb_trimmed (or set it
- * to NULL): After parsing an MLD packet successfully it will point to
- * an skb which has its tail aligned to the IP packet end. This might
- * either be the originally provided skb or a trimmed, cloned version if
- * the skb frame had data beyond the IP packet. A cloned skb allows us
- * to leave the original skb and its full frame unchanged (which might be
- * desirable for layer 2 frame jugglers).
- *
  * Caller needs to set the skb network header and free any returned skb if it
  * differs from the provided skb.
  */
-int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed)
+int ipv6_mc_check_mld(struct sk_buff *skb)
 {
        int ret;
 
@@ -211,6 +195,6 @@ int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed)
        if (ret < 0)
                return ret;
 
-       return __ipv6_mc_check_mld(skb, skb_trimmed);
+       return __ipv6_mc_check_mld(skb);
 }
 EXPORT_SYMBOL(ipv6_mc_check_mld);