diff --git a/net/mptcp/bpf.c b/net/mptcp/bpf.c index 5a0a84ad94af..8a16672b94e2 100644 --- a/net/mptcp/bpf.c +++ b/net/mptcp/bpf.c @@ -19,3 +19,18 @@ struct mptcp_sock *bpf_mptcp_sock_from_subflow(struct sock *sk) return NULL; } + +BTF_SET8_START(bpf_mptcp_fmodret_ids) +BTF_ID_FLAGS(func, update_socket_protocol) +BTF_SET8_END(bpf_mptcp_fmodret_ids) + +static const struct btf_kfunc_id_set bpf_mptcp_fmodret_set = { + .owner = THIS_MODULE, + .set = &bpf_mptcp_fmodret_ids, +}; + +static int __init bpf_mptcp_kfunc_init(void) +{ + return register_btf_fmodret_id_set(&bpf_mptcp_fmodret_set); +} +late_initcall(bpf_mptcp_kfunc_init); diff --git a/net/socket.c b/net/socket.c index 5d4e37595e9a..fdb5233bf560 100644 --- a/net/socket.c +++ b/net/socket.c @@ -1657,12 +1657,36 @@ struct file *__sys_socket_file(int family, int type, int protocol) return sock_alloc_file(sock, flags, NULL); } +/* A hook for bpf progs to attach to and update socket protocol. + * + * A static noinline declaration here could cause the compiler to + * optimize away the function. A global noinline declaration will + * keep the definition, but may optimize away the callsite. + * Therefore, __weak is needed to ensure that the call is still + * emitted, by telling the compiler that we don't know what the + * function might eventually be. + * + * __diag_* below are needed to dismiss the missing prototype warning. + */ + +__diag_push(); +__diag_ignore_all("-Wmissing-prototypes", + "A fmod_ret entry point for BPF programs"); + +__weak noinline int update_socket_protocol(int family, int type, int protocol) +{ + return protocol; +} + +__diag_pop(); + int __sys_socket(int family, int type, int protocol) { struct socket *sock; int flags; - sock = __sys_socket_create(family, type, protocol); + sock = __sys_socket_create(family, type, + update_socket_protocol(family, type, protocol)); if (IS_ERR(sock)) return PTR_ERR(sock); diff --git a/tools/testing/selftests/bpf/prog_tests/mptcp.c b/tools/testing/selftests/bpf/prog_tests/mptcp.c index cd0c42fff7c0..7c0be7cf550b 100644 --- a/tools/testing/selftests/bpf/prog_tests/mptcp.c +++ b/tools/testing/selftests/bpf/prog_tests/mptcp.c @@ -2,17 +2,59 @@ /* Copyright (c) 2020, Tessares SA. */ /* Copyright (c) 2022, SUSE. */ +#include +#include #include #include "cgroup_helpers.h" #include "network_helpers.h" #include "mptcp_sock.skel.h" +#include "mptcpify.skel.h" #define NS_TEST "mptcp_ns" +#ifndef IPPROTO_MPTCP +#define IPPROTO_MPTCP 262 +#endif + +#ifndef SOL_MPTCP +#define SOL_MPTCP 284 +#endif +#ifndef MPTCP_INFO +#define MPTCP_INFO 1 +#endif +#ifndef MPTCP_INFO_FLAG_FALLBACK +#define MPTCP_INFO_FLAG_FALLBACK _BITUL(0) +#endif +#ifndef MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED +#define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1) +#endif + #ifndef TCP_CA_NAME_MAX #define TCP_CA_NAME_MAX 16 #endif +struct __mptcp_info { + __u8 mptcpi_subflows; + __u8 mptcpi_add_addr_signal; + __u8 mptcpi_add_addr_accepted; + __u8 mptcpi_subflows_max; + __u8 mptcpi_add_addr_signal_max; + __u8 mptcpi_add_addr_accepted_max; + __u32 mptcpi_flags; + __u32 mptcpi_token; + __u64 mptcpi_write_seq; + __u64 mptcpi_snd_una; + __u64 mptcpi_rcv_nxt; + __u8 mptcpi_local_addr_used; + __u8 mptcpi_local_addr_max; + __u8 mptcpi_csum_enabled; + __u32 mptcpi_retransmits; + __u64 mptcpi_bytes_retrans; + __u64 mptcpi_bytes_sent; + __u64 mptcpi_bytes_received; + __u64 mptcpi_bytes_acked; +}; + struct mptcp_storage { __u32 invoked; __u32 is_mptcp; @@ -22,6 +64,24 @@ struct mptcp_storage { char ca_name[TCP_CA_NAME_MAX]; }; +static struct nstoken *create_netns(void) +{ + SYS(fail, "ip netns add %s", NS_TEST); + SYS(fail, "ip -net %s link set dev lo up", NS_TEST); + + return open_netns(NS_TEST); +fail: + return NULL; +} + +static void cleanup_netns(struct nstoken *nstoken) +{ + if (nstoken) + close_netns(nstoken); + + SYS_NOFAIL("ip netns del %s &> /dev/null", NS_TEST); +} + static int verify_tsk(int map_fd, int client_fd) { int err, cfd = client_fd; @@ -100,24 +160,14 @@ static int run_test(int cgroup_fd, int server_fd, bool is_mptcp) sock_skel = mptcp_sock__open_and_load(); if (!ASSERT_OK_PTR(sock_skel, "skel_open_load")) - return -EIO; + return libbpf_get_error(sock_skel); err = mptcp_sock__attach(sock_skel); if (!ASSERT_OK(err, "skel_attach")) goto out; prog_fd = bpf_program__fd(sock_skel->progs._sockops); - if (!ASSERT_GE(prog_fd, 0, "bpf_program__fd")) { - err = -EIO; - goto out; - } - map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map); - if (!ASSERT_GE(map_fd, 0, "bpf_map__fd")) { - err = -EIO; - goto out; - } - err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0); if (!ASSERT_OK(err, "bpf_prog_attach")) goto out; @@ -147,11 +197,8 @@ static void test_base(void) if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) return; - SYS(fail, "ip netns add %s", NS_TEST); - SYS(fail, "ip -net %s link set dev lo up", NS_TEST); - - nstoken = open_netns(NS_TEST); - if (!ASSERT_OK_PTR(nstoken, "open_netns")) + nstoken = create_netns(); + if (!ASSERT_OK_PTR(nstoken, "create_netns")) goto fail; /* without MPTCP */ @@ -174,11 +221,104 @@ static void test_base(void) close(server_fd); fail: - if (nstoken) - close_netns(nstoken); + cleanup_netns(nstoken); + close(cgroup_fd); +} - SYS_NOFAIL("ip netns del " NS_TEST " &> /dev/null"); +static void send_byte(int fd) +{ + char b = 0x55; + ASSERT_EQ(write(fd, &b, sizeof(b)), 1, "send single byte"); +} + +static int verify_mptcpify(int server_fd, int client_fd) +{ + struct __mptcp_info info; + socklen_t optlen; + int protocol; + int err = 0; + + optlen = sizeof(protocol); + if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen), + "getsockopt(SOL_PROTOCOL)")) + return -1; + + if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP")) + err++; + + optlen = sizeof(info); + if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen), + "getsockopt(MPTCP_INFO)")) + return -1; + + if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags")) + err++; + if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK, + "MPTCP fallback")) + err++; + if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED, + "no remote key received")) + err++; + + return err; +} + +static int run_mptcpify(int cgroup_fd) +{ + int server_fd, client_fd, err = 0; + struct mptcpify *mptcpify_skel; + + mptcpify_skel = mptcpify__open_and_load(); + if (!ASSERT_OK_PTR(mptcpify_skel, "skel_open_load")) + return libbpf_get_error(mptcpify_skel); + + err = mptcpify__attach(mptcpify_skel); + if (!ASSERT_OK(err, "skel_attach")) + goto out; + + /* without MPTCP */ + server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); + if (!ASSERT_GE(server_fd, 0, "start_server")) { + err = -EIO; + goto out; + } + + client_fd = connect_to_fd(server_fd, 0); + if (!ASSERT_GE(client_fd, 0, "connect to fd")) { + err = -EIO; + goto close_server; + } + + send_byte(client_fd); + + err = verify_mptcpify(server_fd, client_fd); + + close(client_fd); +close_server: + close(server_fd); +out: + mptcpify__destroy(mptcpify_skel); + return err; +} + +static void test_mptcpify(void) +{ + struct nstoken *nstoken = NULL; + int cgroup_fd; + + cgroup_fd = test__join_cgroup("/mptcpify"); + if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) + return; + + nstoken = create_netns(); + if (!ASSERT_OK_PTR(nstoken, "create_netns")) + goto fail; + + ASSERT_OK(run_mptcpify(cgroup_fd), "run_mptcpify"); + +fail: + cleanup_netns(nstoken); close(cgroup_fd); } @@ -186,4 +326,6 @@ void test_mptcp(void) { if (test__start_subtest("base")) test_base(); + if (test__start_subtest("mptcpify")) + test_mptcpify(); } diff --git a/tools/testing/selftests/bpf/progs/mptcpify.c b/tools/testing/selftests/bpf/progs/mptcpify.c new file mode 100644 index 000000000000..53301ae8a8f7 --- /dev/null +++ b/tools/testing/selftests/bpf/progs/mptcpify.c @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2023, SUSE. */ + +#include "vmlinux.h" +#include +#include "bpf_tracing_net.h" + +char _license[] SEC("license") = "GPL"; + +SEC("fmod_ret/update_socket_protocol") +int BPF_PROG(mptcpify, int family, int type, int protocol) +{ + if ((family == AF_INET || family == AF_INET6) && + type == SOCK_STREAM && + (!protocol || protocol == IPPROTO_TCP)) { + return IPPROTO_MPTCP; + } + + return protocol; +}