More changes
This commit is contained in:
parent
0b246e44a8
commit
eef052b1e9
9 changed files with 195 additions and 61 deletions
|
@ -10,8 +10,9 @@ plugin_directories:
|
||||||
- ./plugins
|
- ./plugins
|
||||||
|
|
||||||
server:
|
server:
|
||||||
# The IP:port to listen to.
|
# The IP and port to listen to.
|
||||||
listen: 0.0.0.0:29316
|
hostname: 0.0.0.0
|
||||||
|
port: 29316
|
||||||
# The base management API path.
|
# The base management API path.
|
||||||
base_path: /_matrix/maubot
|
base_path: /_matrix/maubot
|
||||||
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.
|
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.
|
||||||
|
|
|
@ -17,9 +17,14 @@ from sqlalchemy import orm
|
||||||
import sqlalchemy as sql
|
import sqlalchemy as sql
|
||||||
import logging.config
|
import logging.config
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
import sys
|
||||||
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
|
from .db import Base, init as init_db
|
||||||
|
from .server import MaubotServer
|
||||||
|
from .client import Client, init as init_client
|
||||||
from .__meta__ import __version__
|
from .__meta__ import __version__
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
|
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
|
||||||
|
@ -36,7 +41,23 @@ logging.config.dictConfig(copy.deepcopy(config["logging"]))
|
||||||
log = logging.getLogger("maubot")
|
log = logging.getLogger("maubot")
|
||||||
log.debug(f"Initializing maubot {__version__}")
|
log.debug(f"Initializing maubot {__version__}")
|
||||||
|
|
||||||
db_engine = sql.create_engine(config["database"])
|
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
|
||||||
db_factory = orm.sessionmaker(bind=db_engine)
|
db_factory = orm.sessionmaker(bind=db_engine)
|
||||||
db_session = orm.scoping.scoped_session(db_factory)
|
db_session = orm.scoping.scoped_session(db_factory)
|
||||||
Base.metadata.bind=db_engine
|
Base.metadata.bind=db_engine
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
init_db(db_session)
|
||||||
|
init_client(loop)
|
||||||
|
server = MaubotServer(config, loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(server.start())
|
||||||
|
loop.run_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
log.debug("Keyboard interrupt received, stopping...")
|
||||||
|
for client in Client.cache.values():
|
||||||
|
client.stop()
|
||||||
|
loop.run_until_complete(server.stop())
|
||||||
|
sys.exit(0)
|
||||||
|
|
|
@ -13,62 +13,21 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, List, Optional, Union, Callable
|
from typing import Dict, List, Optional
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from mautrix import Client as MatrixClient
|
from mautrix.types import UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, EventType
|
||||||
from mautrix.client import EventHandler
|
|
||||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
|
|
||||||
EventType, MessageEvent)
|
|
||||||
|
|
||||||
from .command_spec import ParsedCommand
|
|
||||||
from .db import DBClient
|
from .db import DBClient
|
||||||
|
from .matrix import MaubotMatrixClient
|
||||||
|
|
||||||
log = logging.getLogger("maubot.client")
|
log = logging.getLogger("maubot.client")
|
||||||
|
|
||||||
|
|
||||||
class MaubotMatrixClient(MatrixClient):
|
|
||||||
def __init__(self, maubot_client: 'Client', *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._maubot_client = maubot_client
|
|
||||||
self.command_handlers: Dict[str, List[EventHandler]] = {}
|
|
||||||
self.commands: List[ParsedCommand] = []
|
|
||||||
|
|
||||||
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
|
|
||||||
|
|
||||||
async def _command_event_handler(self, evt: MessageEvent) -> None:
|
|
||||||
for command in self.commands:
|
|
||||||
if command.match(evt):
|
|
||||||
await self._trigger_command(command, evt)
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
|
|
||||||
for handler in self.command_handlers.get(command.name, []):
|
|
||||||
await handler(evt)
|
|
||||||
|
|
||||||
def on(self, var: Union[EventHandler, EventType, str]
|
|
||||||
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
|
|
||||||
if isinstance(var, str):
|
|
||||||
def decorator(func: EventHandler) -> EventHandler:
|
|
||||||
self.add_command_handler(var, func)
|
|
||||||
return func
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
return super().on(var)
|
|
||||||
|
|
||||||
def add_command_handler(self, command: str, handler: EventHandler) -> None:
|
|
||||||
self.command_handlers.setdefault(command, []).append(handler)
|
|
||||||
|
|
||||||
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
|
|
||||||
try:
|
|
||||||
self.command_handlers[command].remove(handler)
|
|
||||||
except (KeyError, ValueError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
|
loop: asyncio.AbstractEventLoop
|
||||||
cache: Dict[UserID, 'Client'] = {}
|
cache: Dict[UserID, 'Client'] = {}
|
||||||
http_client: ClientSession = None
|
http_client: ClientSession = None
|
||||||
|
|
||||||
|
@ -78,26 +37,33 @@ class Client:
|
||||||
def __init__(self, db_instance: DBClient) -> None:
|
def __init__(self, db_instance: DBClient) -> None:
|
||||||
self.db_instance = db_instance
|
self.db_instance = db_instance
|
||||||
self.cache[self.id] = self
|
self.cache[self.id] = self
|
||||||
self.client = MaubotMatrixClient(maubot_client=self,
|
self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance,
|
||||||
store=self.db_instance,
|
mxid=self.id, base_url=self.homeserver,
|
||||||
mxid=self.id,
|
token=self.access_token, client_session=self.http_client,
|
||||||
base_url=self.homeserver,
|
|
||||||
token=self.access_token,
|
|
||||||
client_session=self.http_client,
|
|
||||||
log=log.getChild(self.id))
|
log=log.getChild(self.id))
|
||||||
if self.autojoin:
|
if self.autojoin:
|
||||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
asyncio.ensure_future(self.client.start(), loop=self.loop)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self.client.stop()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, user_id: UserID) -> Optional['Client']:
|
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
|
||||||
try:
|
try:
|
||||||
return cls.cache[user_id]
|
return cls.cache[user_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
db_instance = DBClient.query.get(user_id)
|
db_instance = db_instance or DBClient.query.get(user_id)
|
||||||
if not db_instance:
|
if not db_instance:
|
||||||
return None
|
return None
|
||||||
return Client(db_instance)
|
return Client(db_instance)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def all(cls) -> List['Client']:
|
||||||
|
return [cls.get(user.id, user) for user in DBClient.query.all()]
|
||||||
|
|
||||||
# region Properties
|
# region Properties
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -176,3 +142,9 @@ class Client:
|
||||||
async def _handle_invite(self, evt: StateEvent) -> None:
|
async def _handle_invite(self, evt: StateEvent) -> None:
|
||||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||||
await self.client.join_room_by_id(evt.room_id)
|
await self.client.join_room_by_id(evt.room_id)
|
||||||
|
|
||||||
|
|
||||||
|
def init(loop: asyncio.AbstractEventLoop) -> None:
|
||||||
|
Client.loop = loop
|
||||||
|
for client in Client.all():
|
||||||
|
client.start()
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
|
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
|
||||||
from sqlalchemy.orm import Query
|
from sqlalchemy.orm import Query, scoped_session
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
@ -89,3 +89,9 @@ class DBCommandSpec(Base):
|
||||||
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
|
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
|
||||||
primary_key=True)
|
primary_key=True)
|
||||||
spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)
|
spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
def init(session: scoped_session) -> None:
|
||||||
|
DBPlugin.query = session.query_property()
|
||||||
|
DBClient.query = session.query_property()
|
||||||
|
DBCommandSpec.query = session.query_property()
|
||||||
|
|
|
@ -13,14 +13,17 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import TypeVar, Type
|
from typing import TypeVar, Type, Dict
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from ..plugin_base import Plugin
|
from ..plugin_base import Plugin
|
||||||
|
|
||||||
PluginClass = TypeVar("PluginClass", bound=Plugin)
|
PluginClass = TypeVar("PluginClass", bound=Plugin)
|
||||||
|
|
||||||
|
|
||||||
class PluginLoader(ABC):
|
class PluginLoader(ABC):
|
||||||
|
id_cache: Dict[str, 'PluginLoader'] = {}
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
version: str
|
version: str
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,6 @@ class MaubotZipImportError(Exception):
|
||||||
|
|
||||||
class ZippedPluginLoader(PluginLoader):
|
class ZippedPluginLoader(PluginLoader):
|
||||||
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
|
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
|
||||||
id_cache: Dict[str, 'ZippedPluginLoader'] = {}
|
|
||||||
|
|
||||||
path: str
|
path: str
|
||||||
id: str
|
id: str
|
||||||
|
|
77
maubot/matrix.py
Normal file
77
maubot/matrix.py
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
# maubot - A plugin-based Matrix bot system.
|
||||||
|
# Copyright (C) 2018 Tulir Asokan
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Dict, List, Union, Callable
|
||||||
|
|
||||||
|
from mautrix import Client as MatrixClient
|
||||||
|
from mautrix.client import EventHandler
|
||||||
|
from mautrix.types import EventType, MessageEvent
|
||||||
|
|
||||||
|
from .command_spec import ParsedCommand, CommandSpec
|
||||||
|
|
||||||
|
|
||||||
|
class MaubotMatrixClient(MatrixClient):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.command_handlers: Dict[str, List[EventHandler]] = {}
|
||||||
|
self.commands: List[ParsedCommand] = []
|
||||||
|
self.command_specs: Dict[str, CommandSpec] = {}
|
||||||
|
|
||||||
|
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
|
||||||
|
|
||||||
|
def set_command_spec(self, plugin_id: str, spec: CommandSpec) -> None:
|
||||||
|
self.command_specs[plugin_id] = spec
|
||||||
|
self._reparse_command_specs()
|
||||||
|
|
||||||
|
def _reparse_command_specs(self) -> None:
|
||||||
|
self.commands = [parsed_command
|
||||||
|
for spec in self.command_specs.values()
|
||||||
|
for parsed_command in spec.parse()]
|
||||||
|
|
||||||
|
def remove_command_spec(self, plugin_id: str) -> None:
|
||||||
|
try:
|
||||||
|
del self.command_specs[plugin_id]
|
||||||
|
self._reparse_command_specs()
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _command_event_handler(self, evt: MessageEvent) -> None:
|
||||||
|
for command in self.commands:
|
||||||
|
if command.match(evt):
|
||||||
|
await self._trigger_command(command, evt)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
|
||||||
|
for handler in self.command_handlers.get(command.name, []):
|
||||||
|
await handler(evt)
|
||||||
|
|
||||||
|
def on(self, var: Union[EventHandler, EventType, str]
|
||||||
|
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
|
||||||
|
if isinstance(var, str):
|
||||||
|
def decorator(func: EventHandler) -> EventHandler:
|
||||||
|
self.add_command_handler(var, func)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
return super().on(var)
|
||||||
|
|
||||||
|
def add_command_handler(self, command: str, handler: EventHandler) -> None:
|
||||||
|
self.command_handlers.setdefault(command, []).append(handler)
|
||||||
|
|
||||||
|
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
|
||||||
|
try:
|
||||||
|
self.command_handlers[command].remove(handler)
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
pass
|
|
@ -22,11 +22,12 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
class Plugin(ABC):
|
class Plugin(ABC):
|
||||||
def __init__(self, client: 'MaubotMatrixClient') -> None:
|
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
|
self.id = plugin_instance_id
|
||||||
|
|
||||||
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
||||||
pass
|
self.client.set_command_spec(self.id, spec)
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
54
maubot/server.py
Normal file
54
maubot/server.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
# maubot - A plugin-based Matrix bot system.
|
||||||
|
# Copyright (C) 2018 Tulir Asokan
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from aiohttp import web
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from mautrix.api import PathBuilder
|
||||||
|
|
||||||
|
from .config import Config
|
||||||
|
from .__meta__ import __version__
|
||||||
|
|
||||||
|
|
||||||
|
class MaubotServer:
|
||||||
|
def __init__(self, config: Config, loop: asyncio.AbstractEventLoop):
|
||||||
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
|
self.app = web.Application(loop=self.loop)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
path = PathBuilder(config["server.base_path"])
|
||||||
|
self.app.router.add_get(path.version, self.version)
|
||||||
|
|
||||||
|
as_path = PathBuilder(config["server.appservice_base_path"])
|
||||||
|
self.app.router.add_put(as_path.transactions, self.handle_transaction)
|
||||||
|
|
||||||
|
self.runner = web.AppRunner(self.app)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
await self.runner.setup()
|
||||||
|
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
await self.runner.cleanup()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def version(_: web.Request) -> web.Response:
|
||||||
|
return web.json_response({
|
||||||
|
"version": __version__
|
||||||
|
})
|
||||||
|
|
||||||
|
async def handle_transaction(self, request: web.Request) -> web.Response:
|
||||||
|
return web.Response(status=501)
|
Loading…
Reference in a new issue