From 704d5428a4434bfc3062a2b69ab7ebe2d3fe6cc0 Mon Sep 17 00:00:00 2001 From: Sophie Tauchert Date: Sun, 8 Jan 2023 14:54:27 +0100 Subject: [PATCH] Initial commit --- .gitignore | 2 + base-config.yaml | 1 + maubot.yaml | 14 ++++ ntfy/__init__.py | 1 + ntfy/bot.py | 196 +++++++++++++++++++++++++++++++++++++++++++++++ ntfy/config.py | 6 ++ ntfy/db.py | 179 +++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 399 insertions(+) create mode 100644 .gitignore create mode 100644 base-config.yaml create mode 100644 maubot.yaml create mode 100644 ntfy/__init__.py create mode 100644 ntfy/bot.py create mode 100644 ntfy/config.py create mode 100644 ntfy/db.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c87ce24 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.venv +*.mbp diff --git a/base-config.yaml b/base-config.yaml new file mode 100644 index 0000000..b2adc34 --- /dev/null +++ b/base-config.yaml @@ -0,0 +1 @@ +command_prefix: ntfy diff --git a/maubot.yaml b/maubot.yaml new file mode 100644 index 0000000..598892a --- /dev/null +++ b/maubot.yaml @@ -0,0 +1,14 @@ +maubot: 0.3.0 +id: cloud.catgirl.ntfy +version: 0.1.0 +license: AGPL-3.0-or-later +modules: + - ntfy +main_class: NtfyBot + +database: true +database_type: asyncpg + +config: true +extra_files: + - base-config.yaml diff --git a/ntfy/__init__.py b/ntfy/__init__.py new file mode 100644 index 0000000..484419b --- /dev/null +++ b/ntfy/__init__.py @@ -0,0 +1 @@ +from .bot import NtfyBot diff --git a/ntfy/bot.py b/ntfy/bot.py new file mode 100644 index 0000000..a4d18b7 --- /dev/null +++ b/ntfy/bot.py @@ -0,0 +1,196 @@ +import asyncio +import html +import json +from typing import Any, List, Tuple + +from aiohttp import ClientTimeout +from maubot import MessageEvent, Plugin +from maubot.handlers import command +from mautrix.types import Format, MessageType, TextMessageEventContent +from mautrix.util.async_db import UpgradeTable +from mautrix.util.config import BaseProxyConfig +from mautrix.util.formatter import parse_html + +from .config import Config +from .db import DB, Topic, upgrade_table + + +class NtfyBot(Plugin): + db: DB + config: Config + tasks: List[asyncio.Task] = [] + + async def start(self) -> None: + await super().start() + self.config.load_and_update() + self.db = DB(self.database, self.log) + await self.resubscribe() + + async def stop(self) -> None: + await super().stop() + await self.clear_subscriptions() + + async def on_external_config_update(self) -> None: + self.log.info("Refreshing configuration") + self.config.load_and_update() + + async def resubscribe(self) -> None: + await self.clear_subscriptions() + await self.subscribe_to_topics() + + async def clear_subscriptions(self) -> None: + tasks = self.tasks[:] + if not tasks: + return None + + for task in tasks: + if not task.done(): + self.log.debug("cancelling subscription task...") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception as exc: + self.log.exception("Subscription task errored", exc_info=exc) + self.tasks[:] = [] + + @command.new(name=lambda self: self.config["command_prefix"], help="Manage ntfy subscriptions.", require_subcommand=True) + async def ntfy(self) -> None: + pass + + @ntfy.subcommand("subscribe", aliases=("sub"), help="Subscribe this room to a ntfy topic.") + @command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") + async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None: + # see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex + server, topic = topic[0].split("/") + db_topic = await self.db.get_topic(server, topic) + is_fresh_topic = False + if not db_topic: + db_topic = await self.db.create_topic(Topic(id=-1, server=server, topic=topic, last_event_id=None)) + is_fresh_topic = True + sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id) + if sub: + await evt.reply("This room is already subscribed to %s/%s", server, topic) + else: + await self.db.add_subscription(db_topic.id, evt.room_id) + await evt.reply("Subscribed this room to %s/%s", server, topic) + if is_fresh_topic: + await self.subscribe_to_topic(db_topic) + + @ntfy.subcommand("unsubscribe", aliases=("unsub"), help="Unsubscribe this room from a ntfy topic.") + @command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})") + async def unsubscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None: + # see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex + server, topic = topic[0].split("/") + db_topic = await self.db.get_topic(server, topic) + if not db_topic: + await evt.reply("This room is not subscribed to %s/%s", server, topic) + return + sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id) + if not sub: + await evt.reply("This room is not subscribed to %s/%s", server, topic) + return + await self.db.remove_subscription(db_topic.id, evt.room_id) + await evt.reply("Unsubscribed this room from %s/%s", server, topic) + + async def subscribe_to_topics(self) -> None: + topics = await self.db.get_topics() + for topic in topics: + await self.subscribe_to_topic(topic) + + async def subscribe_to_topic(self, topic: Topic) -> None: + def log_task_exc(task: asyncio.Task) -> None: + if task.done() and not task.cancelled(): + exc = task.exception() + self.log.exception( + "Subscription task errored", exc_info=exc) + # TODO: restart subscription# + + self.log.info("Subscribing to %s/%s", topic.server, topic.topic) + url = "%s/%s/json" % (topic.server, topic.topic) + if not url.startswith(("http://", "https://")): + url = "https://" + url + if topic.last_event_id: + url += "?since=%s" % topic.last_event_id + + self.log.debug("Subscribing to URL %s", url) + task = self.loop.create_task( + self.run_topic_subscription(topic, url)) + self.tasks.append(task) + task.add_done_callback(self.tasks.remove) + task.add_done_callback(log_task_exc) + + async def run_topic_subscription(self, topic: Topic, url: str) -> None: + async with self.http.get(url, timeout=ClientTimeout()) as resp: + while True: + line = await resp.content.readline() + # convert to string and remove trailing newline + line = line.decode("utf-8").strip() + self.log.debug("Received notification: %s", line) + message = json.loads(line) + if message["event"] != "message": + continue + # persist the received message id + await self.db.update_topic_id(topic.id, message["id"]) + + # build matrix message + html_content = self.build_message_content( + topic.server, message) + text_content = await parse_html(html_content.strip()) + + content = TextMessageEventContent( + msgtype=MessageType.TEXT, + format=Format.HTML, + formatted_body=html_content, + body=text_content, + ) + + subscriptions = await self.db.get_subscriptions(topic.id) + for sub in subscriptions: + try: + await self.client.send_message(sub.room_id, content) + except Exception as exc: + self.log.exception( + "Failed to send matrix message!", exc_info=exc) + + @classmethod + def build_message_content(cls, server: str, message) -> str: + topic = message["topic"] + body = message["message"] + title = message.get("title", None) + tags = message.get("tags", None) + click = message.get("click", None) + attachment = message.get("attachment", None) + + html_content = "Ntfy message in topic %s/%s
" % ( + html.escape(server), html.escape(topic)) + # build title + if title and click: + html_content += "

%s

" % ( + html.escape(click), html.escape(title)) + elif title: + html_content += "

%s

" % html.escape(title) + + # build body + if click and not title: + html_content += "%s" % (html.escape( + click), html.escape(body).replace("\n", "
")) + else: + html_content += html.escape(body).replace("\n", "
") + + # build attachment + if attachment: + html_content += "
View %s" % (html.escape( + attachment["url"]), html.escape(attachment["name"])) + html_content += "
" + + return html_content + + @classmethod + def get_config_class(cls) -> type[BaseProxyConfig]: + return Config + + @classmethod + def get_db_upgrade_table(cls) -> UpgradeTable | None: + return upgrade_table diff --git a/ntfy/config.py b/ntfy/config.py new file mode 100644 index 0000000..ae256e9 --- /dev/null +++ b/ntfy/config.py @@ -0,0 +1,6 @@ +from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper + + +class Config(BaseProxyConfig): + def do_update(self, helper: ConfigUpdateHelper) -> None: + helper.copy("command_prefix") diff --git a/ntfy/db.py b/ntfy/db.py new file mode 100644 index 0000000..0a5bb1c --- /dev/null +++ b/ntfy/db.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from typing import List, Tuple + +import attr +from asyncpg import Record +from attr import dataclass +from mautrix.types import RoomID +from mautrix.util.async_db import Connection, Database, Scheme, UpgradeTable +from mautrix.util.logging import TraceLogger + +upgrade_table = UpgradeTable() + + +@upgrade_table.register(description="Initial revision") +async def upgrade_v1(conn: Connection, scheme: Scheme) -> None: + gen = "GENERATED ALWAYS AS IDENTITY" if scheme != Scheme.SQLITE else "" + await conn.execute( + f"""CREATE TABLE topics ( + id INTEGER {gen}, + server TEXT NOT NULL, + topic TEXT NOT NULL, + last_event_id TEXT, + + PRIMARY KEY (id), + UNIQUE (server, topic) + )""" + ) + await conn.execute( + """CREATE TABLE subscriptions ( + topic_id INTEGER, + room_id TEXT NOT NULL, + + PRIMARY KEY (topic_id, room_id), + FOREIGN KEY (topic_id) REFERENCES topics (id) + )""" + ) + + +@dataclass +class Topic: + id: int + server: str + topic: str + last_event_id: str + + subscriptions: List[Subscription] = attr.ib(factory=lambda: []) + + @classmethod + def from_row(cls, row: Record | None) -> Topic | None: + if not row: + return None + id = row["id"] + server = row["server"] + topic = row["topic"] + last_event_id = row["last_event_id"] + return cls( + id=id, + server=server, + topic=topic, + last_event_id=last_event_id, + subscriptions=[] + ) + + +@dataclass +class Subscription: + topic_id: int + room_id: RoomID + + @classmethod + def from_row(cls, row: Record | None) -> Topic | None: + if not row: + return None + topic_id = row["topic_id"] + room_id = row["room_id"] + return cls( + topic_id=topic_id, + room_id=room_id, + ) + + +class DB: + db: Database + log: TraceLogger + + def __init__(self, db: Database, log: TraceLogger) -> None: + self.db = db + self.log = log + + async def get_topics(self) -> List[Topic]: + query = """ + SELECT id, server, topic, last_event_id, topic_id, room_id + FROM topics + INNER JOIN + subscriptions ON topics.id = subscriptions.topic_id + """ + rows = await self.db.fetch(query) + topics = {} + for row in rows: + try: + topic = topics[row["id"]] + except KeyError: + topic = topics[row["id"]] = Topic.from_row(row) + topic.subscriptions.append(Subscription.from_row(row)) + return list(topics.values()) + + async def update_topic_id(self, topic_id: int, event_id: str) -> None: + query = """ + UPDATE topics SET last_event_id=$2 WHERE id=$1 + """ + await self.db.execute(query, topic_id, event_id) + + async def create_topic(self, topic: Topic) -> Topic: + query = """ + INSERT INTO topics (server, topic, last_event_id) + VALUES ($1, $2, $3) RETURNING (id) + """ + if self.db.scheme == Scheme.SQLITE: + cur = await self.db.execute( + query.replace("RETURNING (id)", ""), + topic.server, + topic.topic, + topic.last_event_id, + ) + topic.id = cur.lastrowid + else: + topic.id = await self.db.fetchval( + query, + topic.server, + topic.topic, + topic.last_event_id, + ) + return topic + + async def get_topic(self, server: str, topic: str) -> Topic | None: + query = """ + SELECT id, server, topic, last_event_id + FROM topics + WHERE server = $1 AND topic = $2 + """ + return Topic.from_row(await self.db.fetchrow(query, server, topic)) + + async def get_subscription(self, topic_id: int, room_id: RoomID) -> Tuple[Subscription | None, Topic | None]: + query = """ + SELECT id, server, topic, last_event_id, topic_id, room_id + FROM topics + INNER JOIN + subscriptions ON topics.id = subscriptions.topic_id AND subscriptions.room_id = $2 + WHERE topics.id = $1 + """ + row = await self.db.fetchrow(query, topic_id, room_id) + return (Subscription.from_row(row), Topic.from_row(row)) + + async def get_subscriptions(self, topic_id: int) -> List[Subscription]: + query = """ + SELECT topic_id, room_id + FROM subscriptions + WHERE topic_id = $1 + """ + rows = await self.db.fetch(query, topic_id) + subscriptions = [] + for row in rows: + subscriptions.append(Subscription.from_row(row)) + return subscriptions + + async def add_subscription(self, topic_id: int, room_id: RoomID) -> None: + query = """ + INSERT INTO subscriptions (topic_id, room_id) + VALUES ($1, $2) + """ + await self.db.execute(query, topic_id, room_id) + + async def remove_subscription(self, topic_id: int, room_id: RoomID) -> None: + query = """ + DELETE FROM subscriptions + WHERE topic_id = $1 AND room_id = $2 + """ + await self.db.execute(query, topic_id, room_id)