RDMA/cma: Allow UD qp_type to join multicast only
[sfrench/cifs-2.6.git] / drivers / infiniband / core / cma.c
index 3081559377133e4df5f55789435130e48c415d52..6b9563d4f23c94dfab6dbff02a601942a7b315d4 100644 (file)
@@ -624,22 +624,11 @@ static inline unsigned short cma_family(struct rdma_id_private *id_priv)
        return id_priv->id.route.addr.src_addr.ss_family;
 }
 
-static int cma_set_qkey(struct rdma_id_private *id_priv, u32 qkey)
+static int cma_set_default_qkey(struct rdma_id_private *id_priv)
 {
        struct ib_sa_mcmember_rec rec;
        int ret = 0;
 
-       if (id_priv->qkey) {
-               if (qkey && id_priv->qkey != qkey)
-                       return -EINVAL;
-               return 0;
-       }
-
-       if (qkey) {
-               id_priv->qkey = qkey;
-               return 0;
-       }
-
        switch (id_priv->id.ps) {
        case RDMA_PS_UDP:
        case RDMA_PS_IB:
@@ -659,6 +648,16 @@ static int cma_set_qkey(struct rdma_id_private *id_priv, u32 qkey)
        return ret;
 }
 
+static int cma_set_qkey(struct rdma_id_private *id_priv, u32 qkey)
+{
+       if (!qkey ||
+           (id_priv->qkey && (id_priv->qkey != qkey)))
+               return -EINVAL;
+
+       id_priv->qkey = qkey;
+       return 0;
+}
+
 static void cma_translate_ib(struct sockaddr_ib *sib, struct rdma_dev_addr *dev_addr)
 {
        dev_addr->dev_type = ARPHRD_INFINIBAND;
@@ -1229,7 +1228,7 @@ static int cma_ib_init_qp_attr(struct rdma_id_private *id_priv,
        *qp_attr_mask = IB_QP_STATE | IB_QP_PKEY_INDEX | IB_QP_PORT;
 
        if (id_priv->id.qp_type == IB_QPT_UD) {
-               ret = cma_set_qkey(id_priv, 0);
+               ret = cma_set_default_qkey(id_priv);
                if (ret)
                        return ret;
 
@@ -4569,7 +4568,10 @@ static int cma_send_sidr_rep(struct rdma_id_private *id_priv,
        memset(&rep, 0, sizeof rep);
        rep.status = status;
        if (status == IB_SIDR_SUCCESS) {
-               ret = cma_set_qkey(id_priv, qkey);
+               if (qkey)
+                       ret = cma_set_qkey(id_priv, qkey);
+               else
+                       ret = cma_set_default_qkey(id_priv);
                if (ret)
                        return ret;
                rep.qp_num = id_priv->qp_num;
@@ -4774,9 +4776,7 @@ static void cma_make_mc_event(int status, struct rdma_id_private *id_priv,
        enum ib_gid_type gid_type;
        struct net_device *ndev;
 
-       if (!status)
-               status = cma_set_qkey(id_priv, be32_to_cpu(multicast->rec.qkey));
-       else
+       if (status)
                pr_debug_ratelimited("RDMA CM: MULTICAST_ERROR: failed to join multicast. status %d\n",
                                     status);
 
@@ -4804,7 +4804,7 @@ static void cma_make_mc_event(int status, struct rdma_id_private *id_priv,
        }
 
        event->param.ud.qp_num = 0xFFFFFF;
-       event->param.ud.qkey = be32_to_cpu(multicast->rec.qkey);
+       event->param.ud.qkey = id_priv->qkey;
 
 out:
        if (ndev)
@@ -4823,8 +4823,11 @@ static int cma_ib_mc_handler(int status, struct ib_sa_multicast *multicast)
            READ_ONCE(id_priv->state) == RDMA_CM_DESTROYING)
                goto out;
 
-       cma_make_mc_event(status, id_priv, multicast, &event, mc);
-       ret = cma_cm_event_handler(id_priv, &event);
+       ret = cma_set_qkey(id_priv, be32_to_cpu(multicast->rec.qkey));
+       if (!ret) {
+               cma_make_mc_event(status, id_priv, multicast, &event, mc);
+               ret = cma_cm_event_handler(id_priv, &event);
+       }
        rdma_destroy_ah_attr(&event.param.ud.ah_attr);
        WARN_ON(ret);
 
@@ -4877,9 +4880,11 @@ static int cma_join_ib_multicast(struct rdma_id_private *id_priv,
        if (ret)
                return ret;
 
-       ret = cma_set_qkey(id_priv, 0);
-       if (ret)
-               return ret;
+       if (!id_priv->qkey) {
+               ret = cma_set_default_qkey(id_priv);
+               if (ret)
+                       return ret;
+       }
 
        cma_set_mgid(id_priv, (struct sockaddr *) &mc->addr, &rec.mgid);
        rec.qkey = cpu_to_be32(id_priv->qkey);
@@ -4956,9 +4961,6 @@ static int cma_iboe_join_multicast(struct rdma_id_private *id_priv,
        cma_iboe_set_mgid(addr, &ib.rec.mgid, gid_type);
 
        ib.rec.pkey = cpu_to_be16(0xffff);
-       if (id_priv->id.ps == RDMA_PS_UDP)
-               ib.rec.qkey = cpu_to_be32(RDMA_UDP_QKEY);
-
        if (dev_addr->bound_dev_if)
                ndev = dev_get_by_index(dev_addr->net, dev_addr->bound_dev_if);
        if (!ndev)
@@ -4984,6 +4986,9 @@ static int cma_iboe_join_multicast(struct rdma_id_private *id_priv,
        if (err || !ib.rec.mtu)
                return err ?: -EINVAL;
 
+       if (!id_priv->qkey)
+               cma_set_default_qkey(id_priv);
+
        rdma_ip2gid((struct sockaddr *)&id_priv->id.route.addr.src_addr,
                    &ib.rec.port_gid);
        INIT_WORK(&mc->iboe_join.work, cma_iboe_join_work_handler);
@@ -5009,6 +5014,9 @@ int rdma_join_multicast(struct rdma_cm_id *id, struct sockaddr *addr,
                            READ_ONCE(id_priv->state) != RDMA_CM_ADDR_RESOLVED))
                return -EINVAL;
 
+       if (id_priv->id.qp_type != IB_QPT_UD)
+               return -EINVAL;
+
        mc = kzalloc(sizeof(*mc), GFP_KERNEL);
        if (!mc)
                return -ENOMEM;