diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c index 65a937ed5762..ac09b35a9567 100644 --- a/kernel/bpf/sockmap.c +++ b/kernel/bpf/sockmap.c @@ -72,6 +72,7 @@ struct bpf_htab { u32 n_buckets; u32 elem_size; struct bpf_sock_progs progs; + struct rcu_head rcu; }; struct htab_elem { @@ -89,8 +90,8 @@ enum smap_psock_state { struct smap_psock_map_entry { struct list_head list; struct sock **entry; - struct htab_elem *hash_link; - struct bpf_htab *htab; + struct htab_elem __rcu *hash_link; + struct bpf_htab __rcu *htab; }; struct smap_psock { @@ -120,6 +121,7 @@ struct smap_psock { struct bpf_prog *bpf_parse; struct bpf_prog *bpf_verdict; struct list_head maps; + spinlock_t maps_lock; /* Back reference used when sock callback trigger sockmap operations */ struct sock *sock; @@ -258,16 +260,54 @@ out: rcu_read_unlock(); } +static struct htab_elem *lookup_elem_raw(struct hlist_head *head, + u32 hash, void *key, u32 key_size) +{ + struct htab_elem *l; + + hlist_for_each_entry_rcu(l, head, hash_node) { + if (l->hash == hash && !memcmp(&l->key, key, key_size)) + return l; + } + + return NULL; +} + +static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash) +{ + return &htab->buckets[hash & (htab->n_buckets - 1)]; +} + +static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash) +{ + return &__select_bucket(htab, hash)->head; +} + static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l) { atomic_dec(&htab->count); kfree_rcu(l, rcu); } +static struct smap_psock_map_entry *psock_map_pop(struct sock *sk, + struct smap_psock *psock) +{ + struct smap_psock_map_entry *e; + + spin_lock_bh(&psock->maps_lock); + e = list_first_entry_or_null(&psock->maps, + struct smap_psock_map_entry, + list); + if (e) + list_del(&e->list); + spin_unlock_bh(&psock->maps_lock); + return e; +} + static void bpf_tcp_close(struct sock *sk, long timeout) { void (*close_fun)(struct sock *sk, long timeout); - struct smap_psock_map_entry *e, *tmp; + struct smap_psock_map_entry *e; struct sk_msg_buff *md, *mtmp; struct smap_psock *psock; struct sock *osk; @@ -286,7 +326,6 @@ static void bpf_tcp_close(struct sock *sk, long timeout) */ close_fun = psock->save_close; - write_lock_bh(&sk->sk_callback_lock); if (psock->cork) { free_start_sg(psock->sock, psock->cork); kfree(psock->cork); @@ -299,20 +338,38 @@ static void bpf_tcp_close(struct sock *sk, long timeout) kfree(md); } - list_for_each_entry_safe(e, tmp, &psock->maps, list) { + e = psock_map_pop(sk, psock); + while (e) { if (e->entry) { osk = cmpxchg(e->entry, sk, NULL); if (osk == sk) { - list_del(&e->list); smap_release_sock(psock, sk); } } else { - hlist_del_rcu(&e->hash_link->hash_node); - smap_release_sock(psock, e->hash_link->sk); - free_htab_elem(e->htab, e->hash_link); + struct htab_elem *link = rcu_dereference(e->hash_link); + struct bpf_htab *htab = rcu_dereference(e->htab); + struct hlist_head *head; + struct htab_elem *l; + struct bucket *b; + + b = __select_bucket(htab, link->hash); + head = &b->head; + raw_spin_lock_bh(&b->lock); + l = lookup_elem_raw(head, + link->hash, link->key, + htab->map.key_size); + /* If another thread deleted this object skip deletion. + * The refcnt on psock may or may not be zero. + */ + if (l) { + hlist_del_rcu(&link->hash_node); + smap_release_sock(psock, link->sk); + free_htab_elem(htab, link); + } + raw_spin_unlock_bh(&b->lock); } + e = psock_map_pop(sk, psock); } - write_unlock_bh(&sk->sk_callback_lock); rcu_read_unlock(); close_fun(sk, timeout); } @@ -1395,7 +1452,9 @@ static void smap_release_sock(struct smap_psock *psock, struct sock *sock) { if (refcount_dec_and_test(&psock->refcnt)) { tcp_cleanup_ulp(sock); + write_lock_bh(&sock->sk_callback_lock); smap_stop_sock(psock, sock); + write_unlock_bh(&sock->sk_callback_lock); clear_bit(SMAP_TX_RUNNING, &psock->state); rcu_assign_sk_user_data(sock, NULL); call_rcu_sched(&psock->rcu, smap_destroy_psock); @@ -1546,6 +1605,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock, int node) INIT_LIST_HEAD(&psock->maps); INIT_LIST_HEAD(&psock->ingress); refcount_set(&psock->refcnt, 1); + spin_lock_init(&psock->maps_lock); rcu_assign_sk_user_data(sock, psock); sock_hold(sock); @@ -1607,10 +1667,12 @@ static void smap_list_map_remove(struct smap_psock *psock, { struct smap_psock_map_entry *e, *tmp; + spin_lock_bh(&psock->maps_lock); list_for_each_entry_safe(e, tmp, &psock->maps, list) { if (e->entry == entry) list_del(&e->list); } + spin_unlock_bh(&psock->maps_lock); } static void smap_list_hash_remove(struct smap_psock *psock, @@ -1618,12 +1680,14 @@ static void smap_list_hash_remove(struct smap_psock *psock, { struct smap_psock_map_entry *e, *tmp; + spin_lock_bh(&psock->maps_lock); list_for_each_entry_safe(e, tmp, &psock->maps, list) { - struct htab_elem *c = e->hash_link; + struct htab_elem *c = rcu_dereference(e->hash_link); if (c == hash_link) list_del(&e->list); } + spin_unlock_bh(&psock->maps_lock); } static void sock_map_free(struct bpf_map *map) @@ -1649,7 +1713,6 @@ static void sock_map_free(struct bpf_map *map) if (!sock) continue; - write_lock_bh(&sock->sk_callback_lock); psock = smap_psock_sk(sock); /* This check handles a racing sock event that can get the * sk_callback_lock before this case but after xchg happens @@ -1660,7 +1723,6 @@ static void sock_map_free(struct bpf_map *map) smap_list_map_remove(psock, &stab->sock_map[i]); smap_release_sock(psock, sock); } - write_unlock_bh(&sock->sk_callback_lock); } rcu_read_unlock(); @@ -1709,7 +1771,6 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key) if (!sock) return -EINVAL; - write_lock_bh(&sock->sk_callback_lock); psock = smap_psock_sk(sock); if (!psock) goto out; @@ -1719,7 +1780,6 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key) smap_list_map_remove(psock, &stab->sock_map[k]); smap_release_sock(psock, sock); out: - write_unlock_bh(&sock->sk_callback_lock); return 0; } @@ -1800,7 +1860,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, } } - write_lock_bh(&sock->sk_callback_lock); psock = smap_psock_sk(sock); /* 2. Do not allow inheriting programs if psock exists and has @@ -1857,7 +1916,9 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, if (err) goto out_free; smap_init_progs(psock, verdict, parse); + write_lock_bh(&sock->sk_callback_lock); smap_start_sock(psock, sock); + write_unlock_bh(&sock->sk_callback_lock); } /* 4. Place psock in sockmap for use and stop any programs on @@ -1867,9 +1928,10 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, */ if (map_link) { e->entry = map_link; + spin_lock_bh(&psock->maps_lock); list_add_tail(&e->list, &psock->maps); + spin_unlock_bh(&psock->maps_lock); } - write_unlock_bh(&sock->sk_callback_lock); return err; out_free: smap_release_sock(psock, sock); @@ -1880,7 +1942,6 @@ out_progs: } if (tx_msg) bpf_prog_put(tx_msg); - write_unlock_bh(&sock->sk_callback_lock); kfree(e); return err; } @@ -1917,10 +1978,8 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, if (osock) { struct smap_psock *opsock = smap_psock_sk(osock); - write_lock_bh(&osock->sk_callback_lock); smap_list_map_remove(opsock, &stab->sock_map[i]); smap_release_sock(opsock, osock); - write_unlock_bh(&osock->sk_callback_lock); } out: return err; @@ -2109,14 +2168,13 @@ free_htab: return ERR_PTR(err); } -static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash) +static void __bpf_htab_free(struct rcu_head *rcu) { - return &htab->buckets[hash & (htab->n_buckets - 1)]; -} + struct bpf_htab *htab; -static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash) -{ - return &__select_bucket(htab, hash)->head; + htab = container_of(rcu, struct bpf_htab, rcu); + bpf_map_area_free(htab->buckets); + kfree(htab); } static void sock_hash_free(struct bpf_map *map) @@ -2135,16 +2193,18 @@ static void sock_hash_free(struct bpf_map *map) */ rcu_read_lock(); for (i = 0; i < htab->n_buckets; i++) { - struct hlist_head *head = select_bucket(htab, i); + struct bucket *b = __select_bucket(htab, i); + struct hlist_head *head; struct hlist_node *n; struct htab_elem *l; + raw_spin_lock_bh(&b->lock); + head = &b->head; hlist_for_each_entry_safe(l, n, head, hash_node) { struct sock *sock = l->sk; struct smap_psock *psock; hlist_del_rcu(&l->hash_node); - write_lock_bh(&sock->sk_callback_lock); psock = smap_psock_sk(sock); /* This check handles a racing sock event that can get * the sk_callback_lock before this case but after xchg @@ -2155,13 +2215,12 @@ static void sock_hash_free(struct bpf_map *map) smap_list_hash_remove(psock, l); smap_release_sock(psock, sock); } - write_unlock_bh(&sock->sk_callback_lock); - kfree(l); + free_htab_elem(htab, l); } + raw_spin_unlock_bh(&b->lock); } rcu_read_unlock(); - bpf_map_area_free(htab->buckets); - kfree(htab); + call_rcu(&htab->rcu, __bpf_htab_free); } static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab, @@ -2188,19 +2247,6 @@ static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab, return l_new; } -static struct htab_elem *lookup_elem_raw(struct hlist_head *head, - u32 hash, void *key, u32 key_size) -{ - struct htab_elem *l; - - hlist_for_each_entry_rcu(l, head, hash_node) { - if (l->hash == hash && !memcmp(&l->key, key, key_size)) - return l; - } - - return NULL; -} - static inline u32 htab_map_hash(const void *key, u32 key_len) { return jhash(key, key_len, 0); @@ -2320,9 +2366,12 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops, goto bucket_err; } - e->hash_link = l_new; - e->htab = container_of(map, struct bpf_htab, map); + rcu_assign_pointer(e->hash_link, l_new); + rcu_assign_pointer(e->htab, + container_of(map, struct bpf_htab, map)); + spin_lock_bh(&psock->maps_lock); list_add_tail(&e->list, &psock->maps); + spin_unlock_bh(&psock->maps_lock); /* add new element to the head of the list, so that * concurrent search will find it before old elem @@ -2392,7 +2441,6 @@ static int sock_hash_delete_elem(struct bpf_map *map, void *key) struct smap_psock *psock; hlist_del_rcu(&l->hash_node); - write_lock_bh(&sock->sk_callback_lock); psock = smap_psock_sk(sock); /* This check handles a racing sock event that can get the * sk_callback_lock before this case but after xchg happens @@ -2403,7 +2451,6 @@ static int sock_hash_delete_elem(struct bpf_map *map, void *key) smap_list_hash_remove(psock, l); smap_release_sock(psock, sock); } - write_unlock_bh(&sock->sk_callback_lock); free_htab_elem(htab, l); ret = 0; }