Initial commit

This commit is contained in:
Sophie Tauchert 2023-01-08 14:54:27 +01:00
commit 704d5428a4
No known key found for this signature in database
GPG key ID: 52701DE5F5F51125
7 changed files with 399 additions and 0 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
.venv
*.mbp

1
base-config.yaml Normal file
View file

@ -0,0 +1 @@
command_prefix: ntfy

14
maubot.yaml Normal file
View file

@ -0,0 +1,14 @@
maubot: 0.3.0
id: cloud.catgirl.ntfy
version: 0.1.0
license: AGPL-3.0-or-later
modules:
- ntfy
main_class: NtfyBot
database: true
database_type: asyncpg
config: true
extra_files:
- base-config.yaml

1
ntfy/__init__.py Normal file
View file

@ -0,0 +1 @@
from .bot import NtfyBot

196
ntfy/bot.py Normal file
View file

@ -0,0 +1,196 @@
import asyncio
import html
import json
from typing import Any, List, Tuple
from aiohttp import ClientTimeout
from maubot import MessageEvent, Plugin
from maubot.handlers import command
from mautrix.types import Format, MessageType, TextMessageEventContent
from mautrix.util.async_db import UpgradeTable
from mautrix.util.config import BaseProxyConfig
from mautrix.util.formatter import parse_html
from .config import Config
from .db import DB, Topic, upgrade_table
class NtfyBot(Plugin):
db: DB
config: Config
tasks: List[asyncio.Task] = []
async def start(self) -> None:
await super().start()
self.config.load_and_update()
self.db = DB(self.database, self.log)
await self.resubscribe()
async def stop(self) -> None:
await super().stop()
await self.clear_subscriptions()
async def on_external_config_update(self) -> None:
self.log.info("Refreshing configuration")
self.config.load_and_update()
async def resubscribe(self) -> None:
await self.clear_subscriptions()
await self.subscribe_to_topics()
async def clear_subscriptions(self) -> None:
tasks = self.tasks[:]
if not tasks:
return None
for task in tasks:
if not task.done():
self.log.debug("cancelling subscription task...")
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
except Exception as exc:
self.log.exception("Subscription task errored", exc_info=exc)
self.tasks[:] = []
@command.new(name=lambda self: self.config["command_prefix"], help="Manage ntfy subscriptions.", require_subcommand=True)
async def ntfy(self) -> None:
pass
@ntfy.subcommand("subscribe", aliases=("sub"), help="Subscribe this room to a ntfy topic.")
@command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None:
# see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex
server, topic = topic[0].split("/")
db_topic = await self.db.get_topic(server, topic)
is_fresh_topic = False
if not db_topic:
db_topic = await self.db.create_topic(Topic(id=-1, server=server, topic=topic, last_event_id=None))
is_fresh_topic = True
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
if sub:
await evt.reply("This room is already subscribed to %s/%s", server, topic)
else:
await self.db.add_subscription(db_topic.id, evt.room_id)
await evt.reply("Subscribed this room to %s/%s", server, topic)
if is_fresh_topic:
await self.subscribe_to_topic(db_topic)
@ntfy.subcommand("unsubscribe", aliases=("unsub"), help="Unsubscribe this room from a ntfy topic.")
@command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
async def unsubscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None:
# see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex
server, topic = topic[0].split("/")
db_topic = await self.db.get_topic(server, topic)
if not db_topic:
await evt.reply("This room is not subscribed to %s/%s", server, topic)
return
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
if not sub:
await evt.reply("This room is not subscribed to %s/%s", server, topic)
return
await self.db.remove_subscription(db_topic.id, evt.room_id)
await evt.reply("Unsubscribed this room from %s/%s", server, topic)
async def subscribe_to_topics(self) -> None:
topics = await self.db.get_topics()
for topic in topics:
await self.subscribe_to_topic(topic)
async def subscribe_to_topic(self, topic: Topic) -> None:
def log_task_exc(task: asyncio.Task) -> None:
if task.done() and not task.cancelled():
exc = task.exception()
self.log.exception(
"Subscription task errored", exc_info=exc)
# TODO: restart subscription#
self.log.info("Subscribing to %s/%s", topic.server, topic.topic)
url = "%s/%s/json" % (topic.server, topic.topic)
if not url.startswith(("http://", "https://")):
url = "https://" + url
if topic.last_event_id:
url += "?since=%s" % topic.last_event_id
self.log.debug("Subscribing to URL %s", url)
task = self.loop.create_task(
self.run_topic_subscription(topic, url))
self.tasks.append(task)
task.add_done_callback(self.tasks.remove)
task.add_done_callback(log_task_exc)
async def run_topic_subscription(self, topic: Topic, url: str) -> None:
async with self.http.get(url, timeout=ClientTimeout()) as resp:
while True:
line = await resp.content.readline()
# convert to string and remove trailing newline
line = line.decode("utf-8").strip()
self.log.debug("Received notification: %s", line)
message = json.loads(line)
if message["event"] != "message":
continue
# persist the received message id
await self.db.update_topic_id(topic.id, message["id"])
# build matrix message
html_content = self.build_message_content(
topic.server, message)
text_content = await parse_html(html_content.strip())
content = TextMessageEventContent(
msgtype=MessageType.TEXT,
format=Format.HTML,
formatted_body=html_content,
body=text_content,
)
subscriptions = await self.db.get_subscriptions(topic.id)
for sub in subscriptions:
try:
await self.client.send_message(sub.room_id, content)
except Exception as exc:
self.log.exception(
"Failed to send matrix message!", exc_info=exc)
@classmethod
def build_message_content(cls, server: str, message) -> str:
topic = message["topic"]
body = message["message"]
title = message.get("title", None)
tags = message.get("tags", None)
click = message.get("click", None)
attachment = message.get("attachment", None)
html_content = "<span>Ntfy message in topic <code>%s/%s</code></span><blockquote>" % (
html.escape(server), html.escape(topic))
# build title
if title and click:
html_content += "<h4><a href=\"%s\">%s</a></h4>" % (
html.escape(click), html.escape(title))
elif title:
html_content += "<h4>%s</h4>" % html.escape(title)
# build body
if click and not title:
html_content += "<a href=\"%s\">%s</a>" % (html.escape(
click), html.escape(body).replace("\n", "<br />"))
else:
html_content += html.escape(body).replace("\n", "<br />")
# build attachment
if attachment:
html_content += "<br/><a href=\"%s\">View %s</a>" % (html.escape(
attachment["url"]), html.escape(attachment["name"]))
html_content += "</blockquote>"
return html_content
@classmethod
def get_config_class(cls) -> type[BaseProxyConfig]:
return Config
@classmethod
def get_db_upgrade_table(cls) -> UpgradeTable | None:
return upgrade_table

6
ntfy/config.py Normal file
View file

@ -0,0 +1,6 @@
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
class Config(BaseProxyConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None:
helper.copy("command_prefix")

179
ntfy/db.py Normal file
View file

@ -0,0 +1,179 @@
from __future__ import annotations
from typing import List, Tuple
import attr
from asyncpg import Record
from attr import dataclass
from mautrix.types import RoomID
from mautrix.util.async_db import Connection, Database, Scheme, UpgradeTable
from mautrix.util.logging import TraceLogger
upgrade_table = UpgradeTable()
@upgrade_table.register(description="Initial revision")
async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
gen = "GENERATED ALWAYS AS IDENTITY" if scheme != Scheme.SQLITE else ""
await conn.execute(
f"""CREATE TABLE topics (
id INTEGER {gen},
server TEXT NOT NULL,
topic TEXT NOT NULL,
last_event_id TEXT,
PRIMARY KEY (id),
UNIQUE (server, topic)
)"""
)
await conn.execute(
"""CREATE TABLE subscriptions (
topic_id INTEGER,
room_id TEXT NOT NULL,
PRIMARY KEY (topic_id, room_id),
FOREIGN KEY (topic_id) REFERENCES topics (id)
)"""
)
@dataclass
class Topic:
id: int
server: str
topic: str
last_event_id: str
subscriptions: List[Subscription] = attr.ib(factory=lambda: [])
@classmethod
def from_row(cls, row: Record | None) -> Topic | None:
if not row:
return None
id = row["id"]
server = row["server"]
topic = row["topic"]
last_event_id = row["last_event_id"]
return cls(
id=id,
server=server,
topic=topic,
last_event_id=last_event_id,
subscriptions=[]
)
@dataclass
class Subscription:
topic_id: int
room_id: RoomID
@classmethod
def from_row(cls, row: Record | None) -> Topic | None:
if not row:
return None
topic_id = row["topic_id"]
room_id = row["room_id"]
return cls(
topic_id=topic_id,
room_id=room_id,
)
class DB:
db: Database
log: TraceLogger
def __init__(self, db: Database, log: TraceLogger) -> None:
self.db = db
self.log = log
async def get_topics(self) -> List[Topic]:
query = """
SELECT id, server, topic, last_event_id, topic_id, room_id
FROM topics
INNER JOIN
subscriptions ON topics.id = subscriptions.topic_id
"""
rows = await self.db.fetch(query)
topics = {}
for row in rows:
try:
topic = topics[row["id"]]
except KeyError:
topic = topics[row["id"]] = Topic.from_row(row)
topic.subscriptions.append(Subscription.from_row(row))
return list(topics.values())
async def update_topic_id(self, topic_id: int, event_id: str) -> None:
query = """
UPDATE topics SET last_event_id=$2 WHERE id=$1
"""
await self.db.execute(query, topic_id, event_id)
async def create_topic(self, topic: Topic) -> Topic:
query = """
INSERT INTO topics (server, topic, last_event_id)
VALUES ($1, $2, $3) RETURNING (id)
"""
if self.db.scheme == Scheme.SQLITE:
cur = await self.db.execute(
query.replace("RETURNING (id)", ""),
topic.server,
topic.topic,
topic.last_event_id,
)
topic.id = cur.lastrowid
else:
topic.id = await self.db.fetchval(
query,
topic.server,
topic.topic,
topic.last_event_id,
)
return topic
async def get_topic(self, server: str, topic: str) -> Topic | None:
query = """
SELECT id, server, topic, last_event_id
FROM topics
WHERE server = $1 AND topic = $2
"""
return Topic.from_row(await self.db.fetchrow(query, server, topic))
async def get_subscription(self, topic_id: int, room_id: RoomID) -> Tuple[Subscription | None, Topic | None]:
query = """
SELECT id, server, topic, last_event_id, topic_id, room_id
FROM topics
INNER JOIN
subscriptions ON topics.id = subscriptions.topic_id AND subscriptions.room_id = $2
WHERE topics.id = $1
"""
row = await self.db.fetchrow(query, topic_id, room_id)
return (Subscription.from_row(row), Topic.from_row(row))
async def get_subscriptions(self, topic_id: int) -> List[Subscription]:
query = """
SELECT topic_id, room_id
FROM subscriptions
WHERE topic_id = $1
"""
rows = await self.db.fetch(query, topic_id)
subscriptions = []
for row in rows:
subscriptions.append(Subscription.from_row(row))
return subscriptions
async def add_subscription(self, topic_id: int, room_id: RoomID) -> None:
query = """
INSERT INTO subscriptions (topic_id, room_id)
VALUES ($1, $2)
"""
await self.db.execute(query, topic_id, room_id)
async def remove_subscription(self, topic_id: int, room_id: RoomID) -> None:
query = """
DELETE FROM subscriptions
WHERE topic_id = $1 AND room_id = $2
"""
await self.db.execute(query, topic_id, room_id)