diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c index a46067c38bf5..0fad1331303c 100644 --- a/drivers/net/wireguard/device.c +++ b/drivers/net/wireguard/device.c @@ -59,9 +59,7 @@ out: return ret; } -#ifdef CONFIG_PM_SLEEP -static int wg_pm_notification(struct notifier_block *nb, unsigned long action, - void *data) +static int wg_pm_notification(struct notifier_block *nb, unsigned long action, void *data) { struct wg_device *wg; struct wg_peer *peer; @@ -92,7 +90,24 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action, } static struct notifier_block pm_notifier = { .notifier_call = wg_pm_notification }; -#endif + +static int wg_vm_notification(struct notifier_block *nb, unsigned long action, void *data) +{ + struct wg_device *wg; + struct wg_peer *peer; + + rtnl_lock(); + list_for_each_entry(wg, &device_list, device_list) { + mutex_lock(&wg->device_update_lock); + list_for_each_entry(peer, &wg->peer_list, peer_list) + wg_noise_expire_current_peer_keypairs(peer); + mutex_unlock(&wg->device_update_lock); + } + rtnl_unlock(); + return 0; +} + +static struct notifier_block vm_notifier = { .notifier_call = wg_vm_notification }; static int wg_stop(struct net_device *dev) { @@ -424,15 +439,17 @@ int __init wg_device_init(void) { int ret; -#ifdef CONFIG_PM_SLEEP ret = register_pm_notifier(&pm_notifier); if (ret) return ret; -#endif + + ret = register_random_vmfork_notifier(&vm_notifier); + if (ret) + goto error_pm; ret = register_pernet_device(&pernet_ops); if (ret) - goto error_pm; + goto error_vm; ret = rtnl_link_register(&link_ops); if (ret) @@ -442,10 +459,10 @@ int __init wg_device_init(void) error_pernet: unregister_pernet_device(&pernet_ops); +error_vm: + unregister_random_vmfork_notifier(&vm_notifier); error_pm: -#ifdef CONFIG_PM_SLEEP unregister_pm_notifier(&pm_notifier); -#endif return ret; } @@ -453,8 +470,7 @@ void wg_device_uninit(void) { rtnl_link_unregister(&link_ops); unregister_pernet_device(&pernet_ops); -#ifdef CONFIG_PM_SLEEP + unregister_random_vmfork_notifier(&vm_notifier); unregister_pm_notifier(&pm_notifier); -#endif rcu_barrier(); }