Initial commit
This commit is contained in:
commit
704d5428a4
7 changed files with 399 additions and 0 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
.venv
|
||||
*.mbp
|
1
base-config.yaml
Normal file
1
base-config.yaml
Normal file
|
@ -0,0 +1 @@
|
|||
command_prefix: ntfy
|
14
maubot.yaml
Normal file
14
maubot.yaml
Normal file
|
@ -0,0 +1,14 @@
|
|||
maubot: 0.3.0
|
||||
id: cloud.catgirl.ntfy
|
||||
version: 0.1.0
|
||||
license: AGPL-3.0-or-later
|
||||
modules:
|
||||
- ntfy
|
||||
main_class: NtfyBot
|
||||
|
||||
database: true
|
||||
database_type: asyncpg
|
||||
|
||||
config: true
|
||||
extra_files:
|
||||
- base-config.yaml
|
1
ntfy/__init__.py
Normal file
1
ntfy/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .bot import NtfyBot
|
196
ntfy/bot.py
Normal file
196
ntfy/bot.py
Normal file
|
@ -0,0 +1,196 @@
|
|||
import asyncio
|
||||
import html
|
||||
import json
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from aiohttp import ClientTimeout
|
||||
from maubot import MessageEvent, Plugin
|
||||
from maubot.handlers import command
|
||||
from mautrix.types import Format, MessageType, TextMessageEventContent
|
||||
from mautrix.util.async_db import UpgradeTable
|
||||
from mautrix.util.config import BaseProxyConfig
|
||||
from mautrix.util.formatter import parse_html
|
||||
|
||||
from .config import Config
|
||||
from .db import DB, Topic, upgrade_table
|
||||
|
||||
|
||||
class NtfyBot(Plugin):
|
||||
db: DB
|
||||
config: Config
|
||||
tasks: List[asyncio.Task] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
await super().start()
|
||||
self.config.load_and_update()
|
||||
self.db = DB(self.database, self.log)
|
||||
await self.resubscribe()
|
||||
|
||||
async def stop(self) -> None:
|
||||
await super().stop()
|
||||
await self.clear_subscriptions()
|
||||
|
||||
async def on_external_config_update(self) -> None:
|
||||
self.log.info("Refreshing configuration")
|
||||
self.config.load_and_update()
|
||||
|
||||
async def resubscribe(self) -> None:
|
||||
await self.clear_subscriptions()
|
||||
await self.subscribe_to_topics()
|
||||
|
||||
async def clear_subscriptions(self) -> None:
|
||||
tasks = self.tasks[:]
|
||||
if not tasks:
|
||||
return None
|
||||
|
||||
for task in tasks:
|
||||
if not task.done():
|
||||
self.log.debug("cancelling subscription task...")
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self.log.exception("Subscription task errored", exc_info=exc)
|
||||
self.tasks[:] = []
|
||||
|
||||
@command.new(name=lambda self: self.config["command_prefix"], help="Manage ntfy subscriptions.", require_subcommand=True)
|
||||
async def ntfy(self) -> None:
|
||||
pass
|
||||
|
||||
@ntfy.subcommand("subscribe", aliases=("sub"), help="Subscribe this room to a ntfy topic.")
|
||||
@command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
|
||||
async def subscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None:
|
||||
# see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex
|
||||
server, topic = topic[0].split("/")
|
||||
db_topic = await self.db.get_topic(server, topic)
|
||||
is_fresh_topic = False
|
||||
if not db_topic:
|
||||
db_topic = await self.db.create_topic(Topic(id=-1, server=server, topic=topic, last_event_id=None))
|
||||
is_fresh_topic = True
|
||||
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
|
||||
if sub:
|
||||
await evt.reply("This room is already subscribed to %s/%s", server, topic)
|
||||
else:
|
||||
await self.db.add_subscription(db_topic.id, evt.room_id)
|
||||
await evt.reply("Subscribed this room to %s/%s", server, topic)
|
||||
if is_fresh_topic:
|
||||
await self.subscribe_to_topic(db_topic)
|
||||
|
||||
@ntfy.subcommand("unsubscribe", aliases=("unsub"), help="Unsubscribe this room from a ntfy topic.")
|
||||
@command.argument("topic", "topic URL", matches="(([a-zA-Z0-9-]{1,63}\\.)+[a-zA-Z]{2,6}/[a-zA-Z0-9_-]{1,64})")
|
||||
async def unsubscribe(self, evt: MessageEvent, topic: Tuple[str, Any]) -> None:
|
||||
# see https://github.com/binwiederhier/ntfy/blob/82df434d19e3ef45ada9c00dfe9fc0f8dfba15e6/server/server.go#L61 for the valid topic regex
|
||||
server, topic = topic[0].split("/")
|
||||
db_topic = await self.db.get_topic(server, topic)
|
||||
if not db_topic:
|
||||
await evt.reply("This room is not subscribed to %s/%s", server, topic)
|
||||
return
|
||||
sub, _ = await self.db.get_subscription(db_topic.id, evt.room_id)
|
||||
if not sub:
|
||||
await evt.reply("This room is not subscribed to %s/%s", server, topic)
|
||||
return
|
||||
await self.db.remove_subscription(db_topic.id, evt.room_id)
|
||||
await evt.reply("Unsubscribed this room from %s/%s", server, topic)
|
||||
|
||||
async def subscribe_to_topics(self) -> None:
|
||||
topics = await self.db.get_topics()
|
||||
for topic in topics:
|
||||
await self.subscribe_to_topic(topic)
|
||||
|
||||
async def subscribe_to_topic(self, topic: Topic) -> None:
|
||||
def log_task_exc(task: asyncio.Task) -> None:
|
||||
if task.done() and not task.cancelled():
|
||||
exc = task.exception()
|
||||
self.log.exception(
|
||||
"Subscription task errored", exc_info=exc)
|
||||
# TODO: restart subscription#
|
||||
|
||||
self.log.info("Subscribing to %s/%s", topic.server, topic.topic)
|
||||
url = "%s/%s/json" % (topic.server, topic.topic)
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
if topic.last_event_id:
|
||||
url += "?since=%s" % topic.last_event_id
|
||||
|
||||
self.log.debug("Subscribing to URL %s", url)
|
||||
task = self.loop.create_task(
|
||||
self.run_topic_subscription(topic, url))
|
||||
self.tasks.append(task)
|
||||
task.add_done_callback(self.tasks.remove)
|
||||
task.add_done_callback(log_task_exc)
|
||||
|
||||
async def run_topic_subscription(self, topic: Topic, url: str) -> None:
|
||||
async with self.http.get(url, timeout=ClientTimeout()) as resp:
|
||||
while True:
|
||||
line = await resp.content.readline()
|
||||
# convert to string and remove trailing newline
|
||||
line = line.decode("utf-8").strip()
|
||||
self.log.debug("Received notification: %s", line)
|
||||
message = json.loads(line)
|
||||
if message["event"] != "message":
|
||||
continue
|
||||
# persist the received message id
|
||||
await self.db.update_topic_id(topic.id, message["id"])
|
||||
|
||||
# build matrix message
|
||||
html_content = self.build_message_content(
|
||||
topic.server, message)
|
||||
text_content = await parse_html(html_content.strip())
|
||||
|
||||
content = TextMessageEventContent(
|
||||
msgtype=MessageType.TEXT,
|
||||
format=Format.HTML,
|
||||
formatted_body=html_content,
|
||||
body=text_content,
|
||||
)
|
||||
|
||||
subscriptions = await self.db.get_subscriptions(topic.id)
|
||||
for sub in subscriptions:
|
||||
try:
|
||||
await self.client.send_message(sub.room_id, content)
|
||||
except Exception as exc:
|
||||
self.log.exception(
|
||||
"Failed to send matrix message!", exc_info=exc)
|
||||
|
||||
@classmethod
|
||||
def build_message_content(cls, server: str, message) -> str:
|
||||
topic = message["topic"]
|
||||
body = message["message"]
|
||||
title = message.get("title", None)
|
||||
tags = message.get("tags", None)
|
||||
click = message.get("click", None)
|
||||
attachment = message.get("attachment", None)
|
||||
|
||||
html_content = "<span>Ntfy message in topic <code>%s/%s</code></span><blockquote>" % (
|
||||
html.escape(server), html.escape(topic))
|
||||
# build title
|
||||
if title and click:
|
||||
html_content += "<h4><a href=\"%s\">%s</a></h4>" % (
|
||||
html.escape(click), html.escape(title))
|
||||
elif title:
|
||||
html_content += "<h4>%s</h4>" % html.escape(title)
|
||||
|
||||
# build body
|
||||
if click and not title:
|
||||
html_content += "<a href=\"%s\">%s</a>" % (html.escape(
|
||||
click), html.escape(body).replace("\n", "<br />"))
|
||||
else:
|
||||
html_content += html.escape(body).replace("\n", "<br />")
|
||||
|
||||
# build attachment
|
||||
if attachment:
|
||||
html_content += "<br/><a href=\"%s\">View %s</a>" % (html.escape(
|
||||
attachment["url"]), html.escape(attachment["name"]))
|
||||
html_content += "</blockquote>"
|
||||
|
||||
return html_content
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> type[BaseProxyConfig]:
|
||||
return Config
|
||||
|
||||
@classmethod
|
||||
def get_db_upgrade_table(cls) -> UpgradeTable | None:
|
||||
return upgrade_table
|
6
ntfy/config.py
Normal file
6
ntfy/config.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
||||
|
||||
|
||||
class Config(BaseProxyConfig):
|
||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||
helper.copy("command_prefix")
|
179
ntfy/db.py
Normal file
179
ntfy/db.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import attr
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
from mautrix.types import RoomID
|
||||
from mautrix.util.async_db import Connection, Database, Scheme, UpgradeTable
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
||||
upgrade_table = UpgradeTable()
|
||||
|
||||
|
||||
@upgrade_table.register(description="Initial revision")
|
||||
async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
|
||||
gen = "GENERATED ALWAYS AS IDENTITY" if scheme != Scheme.SQLITE else ""
|
||||
await conn.execute(
|
||||
f"""CREATE TABLE topics (
|
||||
id INTEGER {gen},
|
||||
server TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
last_event_id TEXT,
|
||||
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (server, topic)
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE subscriptions (
|
||||
topic_id INTEGER,
|
||||
room_id TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY (topic_id, room_id),
|
||||
FOREIGN KEY (topic_id) REFERENCES topics (id)
|
||||
)"""
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Topic:
|
||||
id: int
|
||||
server: str
|
||||
topic: str
|
||||
last_event_id: str
|
||||
|
||||
subscriptions: List[Subscription] = attr.ib(factory=lambda: [])
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: Record | None) -> Topic | None:
|
||||
if not row:
|
||||
return None
|
||||
id = row["id"]
|
||||
server = row["server"]
|
||||
topic = row["topic"]
|
||||
last_event_id = row["last_event_id"]
|
||||
return cls(
|
||||
id=id,
|
||||
server=server,
|
||||
topic=topic,
|
||||
last_event_id=last_event_id,
|
||||
subscriptions=[]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Subscription:
|
||||
topic_id: int
|
||||
room_id: RoomID
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: Record | None) -> Topic | None:
|
||||
if not row:
|
||||
return None
|
||||
topic_id = row["topic_id"]
|
||||
room_id = row["room_id"]
|
||||
return cls(
|
||||
topic_id=topic_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
|
||||
class DB:
|
||||
db: Database
|
||||
log: TraceLogger
|
||||
|
||||
def __init__(self, db: Database, log: TraceLogger) -> None:
|
||||
self.db = db
|
||||
self.log = log
|
||||
|
||||
async def get_topics(self) -> List[Topic]:
|
||||
query = """
|
||||
SELECT id, server, topic, last_event_id, topic_id, room_id
|
||||
FROM topics
|
||||
INNER JOIN
|
||||
subscriptions ON topics.id = subscriptions.topic_id
|
||||
"""
|
||||
rows = await self.db.fetch(query)
|
||||
topics = {}
|
||||
for row in rows:
|
||||
try:
|
||||
topic = topics[row["id"]]
|
||||
except KeyError:
|
||||
topic = topics[row["id"]] = Topic.from_row(row)
|
||||
topic.subscriptions.append(Subscription.from_row(row))
|
||||
return list(topics.values())
|
||||
|
||||
async def update_topic_id(self, topic_id: int, event_id: str) -> None:
|
||||
query = """
|
||||
UPDATE topics SET last_event_id=$2 WHERE id=$1
|
||||
"""
|
||||
await self.db.execute(query, topic_id, event_id)
|
||||
|
||||
async def create_topic(self, topic: Topic) -> Topic:
|
||||
query = """
|
||||
INSERT INTO topics (server, topic, last_event_id)
|
||||
VALUES ($1, $2, $3) RETURNING (id)
|
||||
"""
|
||||
if self.db.scheme == Scheme.SQLITE:
|
||||
cur = await self.db.execute(
|
||||
query.replace("RETURNING (id)", ""),
|
||||
topic.server,
|
||||
topic.topic,
|
||||
topic.last_event_id,
|
||||
)
|
||||
topic.id = cur.lastrowid
|
||||
else:
|
||||
topic.id = await self.db.fetchval(
|
||||
query,
|
||||
topic.server,
|
||||
topic.topic,
|
||||
topic.last_event_id,
|
||||
)
|
||||
return topic
|
||||
|
||||
async def get_topic(self, server: str, topic: str) -> Topic | None:
|
||||
query = """
|
||||
SELECT id, server, topic, last_event_id
|
||||
FROM topics
|
||||
WHERE server = $1 AND topic = $2
|
||||
"""
|
||||
return Topic.from_row(await self.db.fetchrow(query, server, topic))
|
||||
|
||||
async def get_subscription(self, topic_id: int, room_id: RoomID) -> Tuple[Subscription | None, Topic | None]:
|
||||
query = """
|
||||
SELECT id, server, topic, last_event_id, topic_id, room_id
|
||||
FROM topics
|
||||
INNER JOIN
|
||||
subscriptions ON topics.id = subscriptions.topic_id AND subscriptions.room_id = $2
|
||||
WHERE topics.id = $1
|
||||
"""
|
||||
row = await self.db.fetchrow(query, topic_id, room_id)
|
||||
return (Subscription.from_row(row), Topic.from_row(row))
|
||||
|
||||
async def get_subscriptions(self, topic_id: int) -> List[Subscription]:
|
||||
query = """
|
||||
SELECT topic_id, room_id
|
||||
FROM subscriptions
|
||||
WHERE topic_id = $1
|
||||
"""
|
||||
rows = await self.db.fetch(query, topic_id)
|
||||
subscriptions = []
|
||||
for row in rows:
|
||||
subscriptions.append(Subscription.from_row(row))
|
||||
return subscriptions
|
||||
|
||||
async def add_subscription(self, topic_id: int, room_id: RoomID) -> None:
|
||||
query = """
|
||||
INSERT INTO subscriptions (topic_id, room_id)
|
||||
VALUES ($1, $2)
|
||||
"""
|
||||
await self.db.execute(query, topic_id, room_id)
|
||||
|
||||
async def remove_subscription(self, topic_id: int, room_id: RoomID) -> None:
|
||||
query = """
|
||||
DELETE FROM subscriptions
|
||||
WHERE topic_id = $1 AND room_id = $2
|
||||
"""
|
||||
await self.db.execute(query, topic_id, room_id)
|
Loading…
Reference in a new issue