Make it work
This commit is contained in:
parent
dce2771588
commit
1d8de8b5f2
7 changed files with 139 additions and 32 deletions
|
@ -20,13 +20,13 @@ import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import sys
|
import sys
|
||||||
import os
|
|
||||||
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .db import Base, init as init_db
|
from .db import Base, init as init_db
|
||||||
from .server import MaubotServer
|
from .server import MaubotServer
|
||||||
from .client import Client, init as init_client
|
from .client import Client, init as init_client
|
||||||
from .loader import ZippedPluginLoader, MaubotZipImportError
|
from .loader import ZippedPluginLoader
|
||||||
|
from .plugin import PluginInstance
|
||||||
from .__meta__ import __version__
|
from .__meta__ import __version__
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
|
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
|
||||||
|
@ -57,27 +57,22 @@ loop = asyncio.get_event_loop()
|
||||||
init_db(db_session)
|
init_db(db_session)
|
||||||
init_client(loop)
|
init_client(loop)
|
||||||
server = MaubotServer(config, loop)
|
server = MaubotServer(config, loop)
|
||||||
|
ZippedPluginLoader.load_all(*config["plugin_directories"])
|
||||||
|
plugins = PluginInstance.all()
|
||||||
|
|
||||||
loader_log = logging.getLogger("maubot.loader.zip")
|
for plugin in plugins:
|
||||||
loader_log.debug("Preloading plugins...")
|
plugin.load()
|
||||||
for directory in config["plugin_directories"]:
|
|
||||||
for file in os.listdir(directory):
|
|
||||||
if not file.endswith(".mbp"):
|
|
||||||
continue
|
|
||||||
path = os.path.join(directory, file)
|
|
||||||
try:
|
|
||||||
loader = ZippedPluginLoader.get(path)
|
|
||||||
loader_log.debug(f"Preloaded plugin {loader.id} from {loader.path}.")
|
|
||||||
except MaubotZipImportError:
|
|
||||||
loader_log.exception(f"Failed to load plugin at {path}.")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(server.start())
|
loop.run_until_complete(asyncio.gather(
|
||||||
|
server.start(),
|
||||||
|
*[plugin.start() for plugin in plugins]))
|
||||||
log.debug("Startup actions complete, running forever.")
|
log.debug("Startup actions complete, running forever.")
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
log.debug("Keyboard interrupt received, stopping...")
|
log.debug("Keyboard interrupt received, stopping...")
|
||||||
for client in Client.cache.values():
|
for client in Client.cache.values():
|
||||||
client.stop()
|
client.stop()
|
||||||
|
db_session.commit()
|
||||||
loop.run_until_complete(server.stop())
|
loop.run_until_complete(server.stop())
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
|
@ -37,15 +37,21 @@ class Client:
|
||||||
def __init__(self, db_instance: DBClient) -> None:
|
def __init__(self, db_instance: DBClient) -> None:
|
||||||
self.db_instance = db_instance
|
self.db_instance = db_instance
|
||||||
self.cache[self.id] = self
|
self.cache[self.id] = self
|
||||||
self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance,
|
self.log = log.getChild(self.id)
|
||||||
mxid=self.id, base_url=self.homeserver,
|
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||||
token=self.access_token, client_session=self.http_client,
|
token=self.access_token, client_session=self.http_client,
|
||||||
log=log.getChild(self.id))
|
log=self.log, loop=self.loop, store=self.db_instance)
|
||||||
if self.autojoin:
|
if self.autojoin:
|
||||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
asyncio.ensure_future(self.client.start(), loop=self.loop)
|
asyncio.ensure_future(self._start(), loop=self.loop)
|
||||||
|
|
||||||
|
async def _start(self) -> None:
|
||||||
|
try:
|
||||||
|
await self.client.start()
|
||||||
|
except Exception:
|
||||||
|
self.log.exception("Fail")
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
self.client.stop()
|
self.client.stop()
|
||||||
|
@ -64,6 +70,10 @@ class Client:
|
||||||
def all(cls) -> List['Client']:
|
def all(cls) -> List['Client']:
|
||||||
return [cls.get(user.id, user) for user in DBClient.query.all()]
|
return [cls.get(user.id, user) for user in DBClient.query.all()]
|
||||||
|
|
||||||
|
async def _handle_invite(self, evt: StateEvent) -> None:
|
||||||
|
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||||
|
await self.client.join_room_by_id(evt.room_id)
|
||||||
|
|
||||||
# region Properties
|
# region Properties
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -72,7 +82,7 @@ class Client:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def homeserver(self) -> str:
|
def homeserver(self) -> str:
|
||||||
return self.db_instance.id
|
return self.db_instance.homeserver
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def access_token(self) -> str:
|
def access_token(self) -> str:
|
||||||
|
@ -139,12 +149,9 @@ class Client:
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
async def _handle_invite(self, evt: StateEvent) -> None:
|
|
||||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
|
||||||
await self.client.join_room_by_id(evt.room_id)
|
|
||||||
|
|
||||||
|
|
||||||
def init(loop: asyncio.AbstractEventLoop) -> None:
|
def init(loop: asyncio.AbstractEventLoop) -> None:
|
||||||
|
Client.http_client = ClientSession(loop=loop)
|
||||||
Client.loop = loop
|
Client.loop = loop
|
||||||
for client in Client.all():
|
for client in Client.all():
|
||||||
client.start()
|
client.start()
|
||||||
|
|
|
@ -118,6 +118,12 @@ class zipimporter:
|
||||||
self._files = _read_directory(self.archive)
|
self._files = _read_directory(self.archive)
|
||||||
_zip_directory_cache[self.archive] = self._files
|
_zip_directory_cache[self.archive] = self._files
|
||||||
|
|
||||||
|
def remove_cache(self):
|
||||||
|
try:
|
||||||
|
del _zip_directory_cache[self.archive]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Check whether we can satisfy the import of the module named by
|
# Check whether we can satisfy the import of the module named by
|
||||||
# 'fullname', or whether it could be a portion of a namespace
|
# 'fullname', or whether it could be a portion of a namespace
|
||||||
# package. Return self if we can load it, a string containing the
|
# package. Return self if we can load it, a string containing the
|
||||||
|
|
|
@ -13,11 +13,14 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import TypeVar, Type, Dict
|
from typing import TypeVar, Type, Dict, Set, TYPE_CHECKING
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from ..plugin_base import Plugin
|
from ..plugin_base import Plugin
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..plugin import PluginInstance
|
||||||
|
|
||||||
PluginClass = TypeVar("PluginClass", bound=Plugin)
|
PluginClass = TypeVar("PluginClass", bound=Plugin)
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,9 +31,17 @@ class IDConflictError(Exception):
|
||||||
class PluginLoader(ABC):
|
class PluginLoader(ABC):
|
||||||
id_cache: Dict[str, 'PluginLoader'] = {}
|
id_cache: Dict[str, 'PluginLoader'] = {}
|
||||||
|
|
||||||
|
references: Set['PluginInstance']
|
||||||
id: str
|
id: str
|
||||||
version: str
|
version: str
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.references = set()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find(cls, plugin_id: str) -> 'PluginLoader':
|
||||||
|
return cls.id_cache[plugin_id]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def source(self) -> str:
|
def source(self) -> str:
|
||||||
|
|
|
@ -15,8 +15,10 @@
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, List, Type
|
from typing import Dict, List, Type
|
||||||
from zipfile import ZipFile, BadZipFile
|
from zipfile import ZipFile, BadZipFile
|
||||||
import sys
|
|
||||||
import configparser
|
import configparser
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
from ..lib.zipimport import zipimporter, ZipImportError
|
from ..lib.zipimport import zipimporter, ZipImportError
|
||||||
from ..plugin_base import Plugin
|
from ..plugin_base import Plugin
|
||||||
|
@ -29,6 +31,7 @@ class MaubotZipImportError(Exception):
|
||||||
|
|
||||||
class ZippedPluginLoader(PluginLoader):
|
class ZippedPluginLoader(PluginLoader):
|
||||||
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
|
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
|
||||||
|
log = logging.getLogger("maubot.loader.zip")
|
||||||
|
|
||||||
path: str
|
path: str
|
||||||
id: str
|
id: str
|
||||||
|
@ -40,9 +43,11 @@ class ZippedPluginLoader(PluginLoader):
|
||||||
_importer: zipimporter
|
_importer: zipimporter
|
||||||
|
|
||||||
def __init__(self, path: str) -> None:
|
def __init__(self, path: str) -> None:
|
||||||
|
super().__init__()
|
||||||
self.path = path
|
self.path = path
|
||||||
self.id = None
|
self.id = None
|
||||||
self._loaded = None
|
self._loaded = None
|
||||||
|
self._importer = None
|
||||||
self._load_meta()
|
self._load_meta()
|
||||||
self._run_preload_checks(self._get_importer())
|
self._run_preload_checks(self._get_importer())
|
||||||
try:
|
try:
|
||||||
|
@ -52,6 +57,7 @@ class ZippedPluginLoader(PluginLoader):
|
||||||
pass
|
pass
|
||||||
self.path_cache[self.path] = self
|
self.path_cache[self.path] = self
|
||||||
self.id_cache[self.id] = self
|
self.id_cache[self.id] = self
|
||||||
|
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, path: str) -> 'ZippedPluginLoader':
|
def get(cls, path: str) -> 'ZippedPluginLoader':
|
||||||
|
@ -68,7 +74,7 @@ class ZippedPluginLoader(PluginLoader):
|
||||||
return ("<ZippedPlugin "
|
return ("<ZippedPlugin "
|
||||||
f"path='{self.path}' "
|
f"path='{self.path}' "
|
||||||
f"id='{self.id}' "
|
f"id='{self.id}' "
|
||||||
f"loaded={self._loaded}>")
|
f"loaded={self._loaded is not None}>")
|
||||||
|
|
||||||
def _load_meta(self) -> None:
|
def _load_meta(self) -> None:
|
||||||
try:
|
try:
|
||||||
|
@ -100,10 +106,11 @@ class ZippedPluginLoader(PluginLoader):
|
||||||
|
|
||||||
def _get_importer(self, reset_cache: bool = False) -> zipimporter:
|
def _get_importer(self, reset_cache: bool = False) -> zipimporter:
|
||||||
try:
|
try:
|
||||||
importer = zipimporter(self.path)
|
if not self._importer:
|
||||||
|
self._importer = zipimporter(self.path)
|
||||||
if reset_cache:
|
if reset_cache:
|
||||||
importer.reset_cache()
|
self._importer.reset_cache()
|
||||||
return importer
|
return self._importer
|
||||||
except ZipImportError as e:
|
except ZipImportError as e:
|
||||||
raise MaubotZipImportError("File not found or not a maubot plugin") from e
|
raise MaubotZipImportError("File not found or not a maubot plugin") from e
|
||||||
|
|
||||||
|
@ -127,6 +134,8 @@ class ZippedPluginLoader(PluginLoader):
|
||||||
return self._loaded
|
return self._loaded
|
||||||
importer = self._get_importer(reset_cache=reset_cache)
|
importer = self._get_importer(reset_cache=reset_cache)
|
||||||
self._run_preload_checks(importer)
|
self._run_preload_checks(importer)
|
||||||
|
if reset_cache:
|
||||||
|
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
|
||||||
for module in self.modules:
|
for module in self.modules:
|
||||||
importer.load_module(module)
|
importer.load_module(module)
|
||||||
main_mod = sys.modules[self.main_module]
|
main_mod = sys.modules[self.main_module]
|
||||||
|
@ -134,6 +143,7 @@ class ZippedPluginLoader(PluginLoader):
|
||||||
if not issubclass(plugin, Plugin):
|
if not issubclass(plugin, Plugin):
|
||||||
raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin")
|
raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin")
|
||||||
self._loaded = plugin
|
self._loaded = plugin
|
||||||
|
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
|
||||||
return plugin
|
return plugin
|
||||||
|
|
||||||
def reload(self) -> Type[PluginClass]:
|
def reload(self) -> Type[PluginClass]:
|
||||||
|
@ -144,6 +154,8 @@ class ZippedPluginLoader(PluginLoader):
|
||||||
for name, mod in list(sys.modules.items()):
|
for name, mod in list(sys.modules.items()):
|
||||||
if getattr(mod, "__file__", "").startswith(self.path):
|
if getattr(mod, "__file__", "").startswith(self.path):
|
||||||
del sys.modules[name]
|
del sys.modules[name]
|
||||||
|
self._loaded = None
|
||||||
|
self.log.debug(f"Unloaded plugin {self.id} at {self.path}")
|
||||||
|
|
||||||
def destroy(self) -> None:
|
def destroy(self) -> None:
|
||||||
self.unload()
|
self.unload()
|
||||||
|
@ -155,3 +167,24 @@ class ZippedPluginLoader(PluginLoader):
|
||||||
del self.id_cache[self.id]
|
del self.id_cache[self.id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
self.id = None
|
||||||
|
self.path = None
|
||||||
|
self.version = None
|
||||||
|
self.modules = None
|
||||||
|
if self._importer:
|
||||||
|
self._importer.remove_cache()
|
||||||
|
self._importer = None
|
||||||
|
self._loaded = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_all(cls, *args: str) -> None:
|
||||||
|
cls.log.debug("Preloading plugins...")
|
||||||
|
for directory in args:
|
||||||
|
for file in os.listdir(directory):
|
||||||
|
if not file.endswith(".mbp"):
|
||||||
|
continue
|
||||||
|
path = os.path.join(directory, file)
|
||||||
|
try:
|
||||||
|
ZippedPluginLoader.get(path)
|
||||||
|
except (MaubotZipImportError, IDConflictError):
|
||||||
|
cls.log.exception(f"Failed to load plugin at {path}")
|
||||||
|
|
|
@ -13,12 +13,15 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from mautrix.types import UserID
|
from mautrix.types import UserID
|
||||||
|
|
||||||
from .db import DBPlugin
|
from .db import DBPlugin
|
||||||
|
from .client import Client
|
||||||
|
from .loader import PluginLoader
|
||||||
|
from .plugin_base import Plugin
|
||||||
|
|
||||||
log = logging.getLogger("maubot.plugin")
|
log = logging.getLogger("maubot.plugin")
|
||||||
|
|
||||||
|
@ -27,10 +30,56 @@ class PluginInstance:
|
||||||
cache: Dict[str, 'PluginInstance'] = {}
|
cache: Dict[str, 'PluginInstance'] = {}
|
||||||
plugin_directories: List[str] = []
|
plugin_directories: List[str] = []
|
||||||
|
|
||||||
|
log: logging.Logger
|
||||||
|
loader: PluginLoader
|
||||||
|
client: Client
|
||||||
|
plugin: Plugin
|
||||||
|
|
||||||
def __init__(self, db_instance: DBPlugin):
|
def __init__(self, db_instance: DBPlugin):
|
||||||
self.db_instance = db_instance
|
self.db_instance = db_instance
|
||||||
|
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
|
||||||
self.cache[self.id] = self
|
self.cache[self.id] = self
|
||||||
|
|
||||||
|
def load(self) -> None:
|
||||||
|
try:
|
||||||
|
self.loader = PluginLoader.find(self.type)
|
||||||
|
except KeyError:
|
||||||
|
self.log.error(f"Failed to find loader for type {self.type}")
|
||||||
|
self.db_instance.enabled = False
|
||||||
|
return
|
||||||
|
self.client = Client.get(self.primary_user)
|
||||||
|
if not self.client:
|
||||||
|
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||||
|
self.db_instance.enabled = False
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
self.log.debug(f"Starting...")
|
||||||
|
cls = self.loader.load()
|
||||||
|
self.plugin = cls(self.client.client, self.id, self.log)
|
||||||
|
self.loader.references |= {self}
|
||||||
|
await self.plugin.start()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
self.log.debug("Stopping...")
|
||||||
|
self.loader.references -= {self}
|
||||||
|
await self.plugin.stop()
|
||||||
|
self.plugin = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
|
||||||
|
) -> Optional['PluginInstance']:
|
||||||
|
try:
|
||||||
|
return cls.cache[instance_id]
|
||||||
|
except KeyError:
|
||||||
|
db_instance = db_instance or DBPlugin.query.get(instance_id)
|
||||||
|
if not db_instance:
|
||||||
|
return None
|
||||||
|
return PluginInstance(db_instance)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def all(cls) -> List['PluginInstance']:
|
||||||
|
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
|
||||||
|
|
||||||
# region Properties
|
# region Properties
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from logging import Logger
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -22,9 +23,14 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
class Plugin(ABC):
|
class Plugin(ABC):
|
||||||
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None:
|
client: 'MaubotMatrixClient'
|
||||||
|
id: str
|
||||||
|
log: Logger
|
||||||
|
|
||||||
|
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger) -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
self.id = plugin_instance_id
|
self.id = plugin_instance_id
|
||||||
|
self.log = log
|
||||||
|
|
||||||
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
||||||
self.client.set_command_spec(self.id, spec)
|
self.client.set_command_spec(self.id, spec)
|
||||||
|
|
Loading…
Reference in a new issue