From 420f7bc1cb937ae61ecd3d643955d68d8592e569 Mon Sep 17 00:00:00 2001 From: Sophie Tauchert Date: Tue, 10 Jan 2023 14:31:44 +0100 Subject: [PATCH] Resubscribe when subscription tasks raise an exception --- ntfy/bot.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/ntfy/bot.py b/ntfy/bot.py index 8d7e7eb..55ded00 100644 --- a/ntfy/bot.py +++ b/ntfy/bot.py @@ -1,7 +1,7 @@ import asyncio import html import json -from typing import Any, List, Tuple +from typing import Any, Dict, Tuple from aiohttp import ClientTimeout from maubot import MessageEvent, Plugin @@ -20,7 +20,7 @@ from .emoji import EMOJI_FALLBACK, WHITE_CHECK_MARK, parse_tags class NtfyBot(Plugin): db: DB config: Config - tasks: List[asyncio.Task] = [] + tasks: Dict[int, asyncio.Task] = {} async def start(self) -> None: await super().start() @@ -44,7 +44,7 @@ class NtfyBot(Plugin): await self.subscribe_to_topics() async def clear_subscriptions(self) -> None: - tasks = self.tasks[:] + tasks = list(self.tasks.values()) if not tasks: return None @@ -58,7 +58,7 @@ class NtfyBot(Plugin): pass except Exception as exc: self.log.exception("Subscription task errored", exc_info=exc) - self.tasks[:] = [] + self.tasks.clear() async def can_use_command(self, evt: MessageEvent) -> bool: if evt.sender in self.config["admins"]: @@ -117,15 +117,18 @@ class NtfyBot(Plugin): async def subscribe_to_topics(self) -> None: topics = await self.db.get_topics() for topic in topics: - await self.subscribe_to_topic(topic) + self.subscribe_to_topic(topic) - async def subscribe_to_topic(self, topic: Topic) -> None: + def subscribe_to_topic(self, topic: Topic) -> None: def log_task_exc(task: asyncio.Task) -> None: + t2 = self.tasks.pop(topic.id, None) + if t2 != task: + self.log.warn("stored task doesn't match callback") if task.done() and not task.cancelled(): exc = task.exception() self.log.exception( - "Subscription task errored", exc_info=exc) - # TODO: restart subscription# + "Subscription task errored, resubscribing", exc_info=exc) + self.subscribe_to_topic(topic) self.log.info("Subscribing to %s/%s", topic.server, topic.topic) url = "%s/%s/json" % (topic.server, topic.topic) @@ -137,8 +140,7 @@ class NtfyBot(Plugin): self.log.debug("Subscribing to URL %s", url) task = self.loop.create_task( self.run_topic_subscription(topic, url)) - self.tasks.append(task) - task.add_done_callback(self.tasks.remove) + self.tasks[topic.id] = task task.add_done_callback(log_task_exc) async def run_topic_subscription(self, topic: Topic, url: str) -> None: