Merge tag 'clk-fixes-for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git...
[sfrench/cifs-2.6.git] / tools / testing / selftests / bpf / test_sock_addr.c
index 73b7493d4120991527b61a2a33a9c0784542176b..3f110eaaf29cea214844ff98211697c764b8a870 100644 (file)
@@ -44,6 +44,7 @@
 #define SERV6_V4MAPPED_IP      "::ffff:192.168.0.4"
 #define SRC6_IP                        "::1"
 #define SRC6_REWRITE_IP                "::6"
+#define WILDCARD6_IP           "::"
 #define SERV6_PORT             6060
 #define SERV6_REWRITE_PORT     6666
 
@@ -85,12 +86,14 @@ static int bind4_prog_load(const struct sock_addr_test *test);
 static int bind6_prog_load(const struct sock_addr_test *test);
 static int connect4_prog_load(const struct sock_addr_test *test);
 static int connect6_prog_load(const struct sock_addr_test *test);
+static int sendmsg_allow_prog_load(const struct sock_addr_test *test);
 static int sendmsg_deny_prog_load(const struct sock_addr_test *test);
 static int sendmsg4_rw_asm_prog_load(const struct sock_addr_test *test);
 static int sendmsg4_rw_c_prog_load(const struct sock_addr_test *test);
 static int sendmsg6_rw_asm_prog_load(const struct sock_addr_test *test);
 static int sendmsg6_rw_c_prog_load(const struct sock_addr_test *test);
 static int sendmsg6_rw_v4mapped_prog_load(const struct sock_addr_test *test);
+static int sendmsg6_rw_wildcard_prog_load(const struct sock_addr_test *test);
 
 static struct sock_addr_test tests[] = {
        /* bind */
@@ -462,6 +465,34 @@ static struct sock_addr_test tests[] = {
                SRC6_REWRITE_IP,
                SYSCALL_ENOTSUPP,
        },
+       {
+               "sendmsg6: set dst IP = [::] (BSD'ism)",
+               sendmsg6_rw_wildcard_prog_load,
+               BPF_CGROUP_UDP6_SENDMSG,
+               BPF_CGROUP_UDP6_SENDMSG,
+               AF_INET6,
+               SOCK_DGRAM,
+               SERV6_IP,
+               SERV6_PORT,
+               SERV6_REWRITE_IP,
+               SERV6_REWRITE_PORT,
+               SRC6_REWRITE_IP,
+               SUCCESS,
+       },
+       {
+               "sendmsg6: preserve dst IP = [::] (BSD'ism)",
+               sendmsg_allow_prog_load,
+               BPF_CGROUP_UDP6_SENDMSG,
+               BPF_CGROUP_UDP6_SENDMSG,
+               AF_INET6,
+               SOCK_DGRAM,
+               WILDCARD6_IP,
+               SERV6_PORT,
+               SERV6_REWRITE_IP,
+               SERV6_PORT,
+               SRC6_IP,
+               SUCCESS,
+       },
        {
                "sendmsg6: deny call",
                sendmsg_deny_prog_load,
@@ -734,16 +765,27 @@ static int connect6_prog_load(const struct sock_addr_test *test)
        return load_path(test, CONNECT6_PROG_PATH);
 }
 
-static int sendmsg_deny_prog_load(const struct sock_addr_test *test)
+static int sendmsg_ret_only_prog_load(const struct sock_addr_test *test,
+                                     int32_t rc)
 {
        struct bpf_insn insns[] = {
-               /* return 0 */
-               BPF_MOV64_IMM(BPF_REG_0, 0),
+               /* return rc */
+               BPF_MOV64_IMM(BPF_REG_0, rc),
                BPF_EXIT_INSN(),
        };
        return load_insns(test, insns, sizeof(insns) / sizeof(struct bpf_insn));
 }
 
+static int sendmsg_allow_prog_load(const struct sock_addr_test *test)
+{
+       return sendmsg_ret_only_prog_load(test, /*rc*/ 1);
+}
+
+static int sendmsg_deny_prog_load(const struct sock_addr_test *test)
+{
+       return sendmsg_ret_only_prog_load(test, /*rc*/ 0);
+}
+
 static int sendmsg4_rw_asm_prog_load(const struct sock_addr_test *test)
 {
        struct sockaddr_in dst4_rw_addr;
@@ -864,6 +906,11 @@ static int sendmsg6_rw_v4mapped_prog_load(const struct sock_addr_test *test)
        return sendmsg6_rw_dst_asm_prog_load(test, SERV6_V4MAPPED_IP);
 }
 
+static int sendmsg6_rw_wildcard_prog_load(const struct sock_addr_test *test)
+{
+       return sendmsg6_rw_dst_asm_prog_load(test, WILDCARD6_IP);
+}
+
 static int sendmsg6_rw_c_prog_load(const struct sock_addr_test *test)
 {
        return load_path(test, SENDMSG6_PROG_PATH);
@@ -1395,7 +1442,7 @@ int main(int argc, char **argv)
                goto err;
 
        cgfd = create_and_get_cgroup(CG_PATH);
-       if (!cgfd)
+       if (cgfd < 0)
                goto err;
 
        if (join_cgroup(CG_PATH))