Merge https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next
[sfrench/cifs-2.6.git] / net / core / sock.c
index 788c1372663cbabdd3d2dd0d0274d60b7d63dd2c..eeb6cbac6f4998dbc41fc686e7e882135e45b9e3 100644 (file)
@@ -703,15 +703,17 @@ static int sock_setbindtodevice(struct sock *sk, sockptr_t optval, int optlen)
                        goto out;
        }
 
-       return sock_bindtoindex(sk, index, true);
+       sockopt_lock_sock(sk);
+       ret = sock_bindtoindex_locked(sk, index);
+       sockopt_release_sock(sk);
 out:
 #endif
 
        return ret;
 }
 
-static int sock_getbindtodevice(struct sock *sk, char __user *optval,
-                               int __user *optlen, int len)
+static int sock_getbindtodevice(struct sock *sk, sockptr_t optval,
+                               sockptr_t optlen, int len)
 {
        int ret = -ENOPROTOOPT;
 #ifdef CONFIG_NETDEVICES
@@ -735,12 +737,12 @@ static int sock_getbindtodevice(struct sock *sk, char __user *optval,
        len = strlen(devname) + 1;
 
        ret = -EFAULT;
-       if (copy_to_user(optval, devname, len))
+       if (copy_to_sockptr(optval, devname, len))
                goto out;
 
 zero:
        ret = -EFAULT;
-       if (put_user(len, optlen))
+       if (copy_to_sockptr(optlen, &len, sizeof(int)))
                goto out;
 
        ret = 0;
@@ -1036,17 +1038,51 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
        return 0;
 }
 
+void sockopt_lock_sock(struct sock *sk)
+{
+       /* When current->bpf_ctx is set, the setsockopt is called from
+        * a bpf prog.  bpf has ensured the sk lock has been
+        * acquired before calling setsockopt().
+        */
+       if (has_current_bpf_ctx())
+               return;
+
+       lock_sock(sk);
+}
+EXPORT_SYMBOL(sockopt_lock_sock);
+
+void sockopt_release_sock(struct sock *sk)
+{
+       if (has_current_bpf_ctx())
+               return;
+
+       release_sock(sk);
+}
+EXPORT_SYMBOL(sockopt_release_sock);
+
+bool sockopt_ns_capable(struct user_namespace *ns, int cap)
+{
+       return has_current_bpf_ctx() || ns_capable(ns, cap);
+}
+EXPORT_SYMBOL(sockopt_ns_capable);
+
+bool sockopt_capable(int cap)
+{
+       return has_current_bpf_ctx() || capable(cap);
+}
+EXPORT_SYMBOL(sockopt_capable);
+
 /*
  *     This is meant for all protocols to use and covers goings on
  *     at the socket level. Everything here is generic.
  */
 
-int sock_setsockopt(struct socket *sock, int level, int optname,
-                   sockptr_t optval, unsigned int optlen)
+int sk_setsockopt(struct sock *sk, int level, int optname,
+                 sockptr_t optval, unsigned int optlen)
 {
        struct so_timestamping timestamping;
+       struct socket *sock = sk->sk_socket;
        struct sock_txtime sk_txtime;
-       struct sock *sk = sock->sk;
        int val;
        int valbool;
        struct linger ling;
@@ -1067,11 +1103,11 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
 
        valbool = val ? 1 : 0;
 
-       lock_sock(sk);
+       sockopt_lock_sock(sk);
 
        switch (optname) {
        case SO_DEBUG:
-               if (val && !capable(CAP_NET_ADMIN))
+               if (val && !sockopt_capable(CAP_NET_ADMIN))
                        ret = -EACCES;
                else
                        sock_valbool_flag(sk, SOCK_DBG, valbool);
@@ -1115,7 +1151,7 @@ set_sndbuf:
                break;
 
        case SO_SNDBUFFORCE:
-               if (!capable(CAP_NET_ADMIN)) {
+               if (!sockopt_capable(CAP_NET_ADMIN)) {
                        ret = -EPERM;
                        break;
                }
@@ -1137,7 +1173,7 @@ set_sndbuf:
                break;
 
        case SO_RCVBUFFORCE:
-               if (!capable(CAP_NET_ADMIN)) {
+               if (!sockopt_capable(CAP_NET_ADMIN)) {
                        ret = -EPERM;
                        break;
                }
@@ -1164,8 +1200,8 @@ set_sndbuf:
 
        case SO_PRIORITY:
                if ((val >= 0 && val <= 6) ||
-                   ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) ||
-                   ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN))
+                   sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) ||
+                   sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN))
                        sk->sk_priority = val;
                else
                        ret = -EPERM;
@@ -1228,7 +1264,7 @@ set_sndbuf:
        case SO_RCVLOWAT:
                if (val < 0)
                        val = INT_MAX;
-               if (sock->ops->set_rcvlowat)
+               if (sock && sock->ops->set_rcvlowat)
                        ret = sock->ops->set_rcvlowat(sk, val);
                else
                        WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
@@ -1310,8 +1346,8 @@ set_sndbuf:
                        clear_bit(SOCK_PASSSEC, &sock->flags);
                break;
        case SO_MARK:
-               if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
-                   !ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
+               if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
+                   !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
                        ret = -EPERM;
                        break;
                }
@@ -1319,8 +1355,8 @@ set_sndbuf:
                __sock_set_mark(sk, val);
                break;
        case SO_RCVMARK:
-               if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
-                   !ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
+               if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
+                   !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
                        ret = -EPERM;
                        break;
                }
@@ -1354,7 +1390,7 @@ set_sndbuf:
 #ifdef CONFIG_NET_RX_BUSY_POLL
        case SO_BUSY_POLL:
                /* allow unprivileged users to decrease the value */
-               if ((val > sk->sk_ll_usec) && !capable(CAP_NET_ADMIN))
+               if ((val > sk->sk_ll_usec) && !sockopt_capable(CAP_NET_ADMIN))
                        ret = -EPERM;
                else {
                        if (val < 0)
@@ -1364,13 +1400,13 @@ set_sndbuf:
                }
                break;
        case SO_PREFER_BUSY_POLL:
-               if (valbool && !capable(CAP_NET_ADMIN))
+               if (valbool && !sockopt_capable(CAP_NET_ADMIN))
                        ret = -EPERM;
                else
                        WRITE_ONCE(sk->sk_prefer_busy_poll, valbool);
                break;
        case SO_BUSY_POLL_BUDGET:
-               if (val > READ_ONCE(sk->sk_busy_poll_budget) && !capable(CAP_NET_ADMIN)) {
+               if (val > READ_ONCE(sk->sk_busy_poll_budget) && !sockopt_capable(CAP_NET_ADMIN)) {
                        ret = -EPERM;
                } else {
                        if (val < 0 || val > U16_MAX)
@@ -1441,7 +1477,7 @@ set_sndbuf:
                 * scheduler has enough safe guards.
                 */
                if (sk_txtime.clockid != CLOCK_MONOTONIC &&
-                   !ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
+                   !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
                        ret = -EPERM;
                        break;
                }
@@ -1496,9 +1532,16 @@ set_sndbuf:
                ret = -ENOPROTOOPT;
                break;
        }
-       release_sock(sk);
+       sockopt_release_sock(sk);
        return ret;
 }
+
+int sock_setsockopt(struct socket *sock, int level, int optname,
+                   sockptr_t optval, unsigned int optlen)
+{
+       return sk_setsockopt(sock->sk, level, optname,
+                            optval, optlen);
+}
 EXPORT_SYMBOL(sock_setsockopt);
 
 static const struct cred *sk_get_peer_cred(struct sock *sk)
@@ -1525,22 +1568,25 @@ static void cred_to_ucred(struct pid *pid, const struct cred *cred,
        }
 }
 
-static int groups_to_user(gid_t __user *dst, const struct group_info *src)
+static int groups_to_user(sockptr_t dst, const struct group_info *src)
 {
        struct user_namespace *user_ns = current_user_ns();
        int i;
 
-       for (i = 0; i < src->ngroups; i++)
-               if (put_user(from_kgid_munged(user_ns, src->gid[i]), dst + i))
+       for (i = 0; i < src->ngroups; i++) {
+               gid_t gid = from_kgid_munged(user_ns, src->gid[i]);
+
+               if (copy_to_sockptr_offset(dst, i * sizeof(gid), &gid, sizeof(gid)))
                        return -EFAULT;
+       }
 
        return 0;
 }
 
-int sock_getsockopt(struct socket *sock, int level, int optname,
-                   char __user *optval, int __user *optlen)
+int sk_getsockopt(struct sock *sk, int level, int optname,
+                 sockptr_t optval, sockptr_t optlen)
 {
-       struct sock *sk = sock->sk;
+       struct socket *sock = sk->sk_socket;
 
        union {
                int val;
@@ -1557,7 +1603,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
        int lv = sizeof(int);
        int len;
 
-       if (get_user(len, optlen))
+       if (copy_from_sockptr(&len, optlen, sizeof(int)))
                return -EFAULT;
        if (len < 0)
                return -EINVAL;
@@ -1692,7 +1738,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
                cred_to_ucred(sk->sk_peer_pid, sk->sk_peer_cred, &peercred);
                spin_unlock(&sk->sk_peer_lock);
 
-               if (copy_to_user(optval, &peercred, len))
+               if (copy_to_sockptr(optval, &peercred, len))
                        return -EFAULT;
                goto lenout;
        }
@@ -1710,11 +1756,11 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
                if (len < n * sizeof(gid_t)) {
                        len = n * sizeof(gid_t);
                        put_cred(cred);
-                       return put_user(len, optlen) ? -EFAULT : -ERANGE;
+                       return copy_to_sockptr(optlen, &len, sizeof(int)) ? -EFAULT : -ERANGE;
                }
                len = n * sizeof(gid_t);
 
-               ret = groups_to_user((gid_t __user *)optval, cred->group_info);
+               ret = groups_to_user(optval, cred->group_info);
                put_cred(cred);
                if (ret)
                        return ret;
@@ -1730,7 +1776,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
                        return -ENOTCONN;
                if (lv < len)
                        return -EINVAL;
-               if (copy_to_user(optval, address, len))
+               if (copy_to_sockptr(optval, address, len))
                        return -EFAULT;
                goto lenout;
        }
@@ -1747,7 +1793,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
                break;
 
        case SO_PEERSEC:
-               return security_socket_getpeersec_stream(sock, optval, optlen, len);
+               return security_socket_getpeersec_stream(sock, optval.user, optlen.user, len);
 
        case SO_MARK:
                v.val = sk->sk_mark;
@@ -1779,7 +1825,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
                return sock_getbindtodevice(sk, optval, optlen, len);
 
        case SO_GET_FILTER:
-               len = sk_get_filter(sk, (struct sock_filter __user *)optval, len);
+               len = sk_get_filter(sk, optval, len);
                if (len < 0)
                        return len;
 
@@ -1827,7 +1873,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
                sk_get_meminfo(sk, meminfo);
 
                len = min_t(unsigned int, len, sizeof(meminfo));
-               if (copy_to_user(optval, &meminfo, len))
+               if (copy_to_sockptr(optval, &meminfo, len))
                        return -EFAULT;
 
                goto lenout;
@@ -1896,14 +1942,22 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
 
        if (len > lv)
                len = lv;
-       if (copy_to_user(optval, &v, len))
+       if (copy_to_sockptr(optval, &v, len))
                return -EFAULT;
 lenout:
-       if (put_user(len, optlen))
+       if (copy_to_sockptr(optlen, &len, sizeof(int)))
                return -EFAULT;
        return 0;
 }
 
+int sock_getsockopt(struct socket *sock, int level, int optname,
+                   char __user *optval, int __user *optlen)
+{
+       return sk_getsockopt(sock->sk, level, optname,
+                            USER_SOCKPTR(optval),
+                            USER_SOCKPTR(optlen));
+}
+
 /*
  * Initialize an sk_lock.
  *