Make it work

This commit is contained in:
Tulir Asokan 2018-10-17 01:30:08 +03:00
parent dce2771588
commit 1d8de8b5f2
7 changed files with 139 additions and 32 deletions

View file

@ -20,13 +20,13 @@ import argparse
import asyncio import asyncio
import copy import copy
import sys import sys
import os
from .config import Config from .config import Config
from .db import Base, init as init_db from .db import Base, init as init_db
from .server import MaubotServer from .server import MaubotServer
from .client import Client, init as init_client from .client import Client, init as init_client
from .loader import ZippedPluginLoader, MaubotZipImportError from .loader import ZippedPluginLoader
from .plugin import PluginInstance
from .__meta__ import __version__ from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.", parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
@ -57,27 +57,22 @@ loop = asyncio.get_event_loop()
init_db(db_session) init_db(db_session)
init_client(loop) init_client(loop)
server = MaubotServer(config, loop) server = MaubotServer(config, loop)
ZippedPluginLoader.load_all(*config["plugin_directories"])
plugins = PluginInstance.all()
loader_log = logging.getLogger("maubot.loader.zip") for plugin in plugins:
loader_log.debug("Preloading plugins...") plugin.load()
for directory in config["plugin_directories"]:
for file in os.listdir(directory):
if not file.endswith(".mbp"):
continue
path = os.path.join(directory, file)
try:
loader = ZippedPluginLoader.get(path)
loader_log.debug(f"Preloaded plugin {loader.id} from {loader.path}.")
except MaubotZipImportError:
loader_log.exception(f"Failed to load plugin at {path}.")
try: try:
loop.run_until_complete(server.start()) loop.run_until_complete(asyncio.gather(
server.start(),
*[plugin.start() for plugin in plugins]))
log.debug("Startup actions complete, running forever.") log.debug("Startup actions complete, running forever.")
loop.run_forever() loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
log.debug("Keyboard interrupt received, stopping...") log.debug("Keyboard interrupt received, stopping...")
for client in Client.cache.values(): for client in Client.cache.values():
client.stop() client.stop()
db_session.commit()
loop.run_until_complete(server.stop()) loop.run_until_complete(server.stop())
sys.exit(0) sys.exit(0)

View file

@ -37,15 +37,21 @@ class Client:
def __init__(self, db_instance: DBClient) -> None: def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance self.db_instance = db_instance
self.cache[self.id] = self self.cache[self.id] = self
self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance, self.log = log.getChild(self.id)
mxid=self.id, base_url=self.homeserver, self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client, token=self.access_token, client_session=self.http_client,
log=log.getChild(self.id)) log=self.log, loop=self.loop, store=self.db_instance)
if self.autojoin: if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER) self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
def start(self) -> None: def start(self) -> None:
asyncio.ensure_future(self.client.start(), loop=self.loop) asyncio.ensure_future(self._start(), loop=self.loop)
async def _start(self) -> None:
try:
await self.client.start()
except Exception:
self.log.exception("Fail")
def stop(self) -> None: def stop(self) -> None:
self.client.stop() self.client.stop()
@ -64,6 +70,10 @@ class Client:
def all(cls) -> List['Client']: def all(cls) -> List['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()] return [cls.get(user.id, user) for user in DBClient.query.all()]
async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)
# region Properties # region Properties
@property @property
@ -72,7 +82,7 @@ class Client:
@property @property
def homeserver(self) -> str: def homeserver(self) -> str:
return self.db_instance.id return self.db_instance.homeserver
@property @property
def access_token(self) -> str: def access_token(self) -> str:
@ -139,12 +149,9 @@ class Client:
# endregion # endregion
async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)
def init(loop: asyncio.AbstractEventLoop) -> None: def init(loop: asyncio.AbstractEventLoop) -> None:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop Client.loop = loop
for client in Client.all(): for client in Client.all():
client.start() client.start()

View file

@ -118,6 +118,12 @@ class zipimporter:
self._files = _read_directory(self.archive) self._files = _read_directory(self.archive)
_zip_directory_cache[self.archive] = self._files _zip_directory_cache[self.archive] = self._files
def remove_cache(self):
try:
del _zip_directory_cache[self.archive]
except KeyError:
pass
# Check whether we can satisfy the import of the module named by # Check whether we can satisfy the import of the module named by
# 'fullname', or whether it could be a portion of a namespace # 'fullname', or whether it could be a portion of a namespace
# package. Return self if we can load it, a string containing the # package. Return self if we can load it, a string containing the

View file

@ -13,11 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TypeVar, Type, Dict from typing import TypeVar, Type, Dict, Set, TYPE_CHECKING
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ..plugin_base import Plugin from ..plugin_base import Plugin
if TYPE_CHECKING:
from ..plugin import PluginInstance
PluginClass = TypeVar("PluginClass", bound=Plugin) PluginClass = TypeVar("PluginClass", bound=Plugin)
@ -28,9 +31,17 @@ class IDConflictError(Exception):
class PluginLoader(ABC): class PluginLoader(ABC):
id_cache: Dict[str, 'PluginLoader'] = {} id_cache: Dict[str, 'PluginLoader'] = {}
references: Set['PluginInstance']
id: str id: str
version: str version: str
def __init__(self):
self.references = set()
@classmethod
def find(cls, plugin_id: str) -> 'PluginLoader':
return cls.id_cache[plugin_id]
@property @property
@abstractmethod @abstractmethod
def source(self) -> str: def source(self) -> str:

View file

@ -15,8 +15,10 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Type from typing import Dict, List, Type
from zipfile import ZipFile, BadZipFile from zipfile import ZipFile, BadZipFile
import sys
import configparser import configparser
import logging
import sys
import os
from ..lib.zipimport import zipimporter, ZipImportError from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin from ..plugin_base import Plugin
@ -29,6 +31,7 @@ class MaubotZipImportError(Exception):
class ZippedPluginLoader(PluginLoader): class ZippedPluginLoader(PluginLoader):
path_cache: Dict[str, 'ZippedPluginLoader'] = {} path_cache: Dict[str, 'ZippedPluginLoader'] = {}
log = logging.getLogger("maubot.loader.zip")
path: str path: str
id: str id: str
@ -40,9 +43,11 @@ class ZippedPluginLoader(PluginLoader):
_importer: zipimporter _importer: zipimporter
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
super().__init__()
self.path = path self.path = path
self.id = None self.id = None
self._loaded = None self._loaded = None
self._importer = None
self._load_meta() self._load_meta()
self._run_preload_checks(self._get_importer()) self._run_preload_checks(self._get_importer())
try: try:
@ -52,6 +57,7 @@ class ZippedPluginLoader(PluginLoader):
pass pass
self.path_cache[self.path] = self self.path_cache[self.path] = self
self.id_cache[self.id] = self self.id_cache[self.id] = self
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
@classmethod @classmethod
def get(cls, path: str) -> 'ZippedPluginLoader': def get(cls, path: str) -> 'ZippedPluginLoader':
@ -68,7 +74,7 @@ class ZippedPluginLoader(PluginLoader):
return ("<ZippedPlugin " return ("<ZippedPlugin "
f"path='{self.path}' " f"path='{self.path}' "
f"id='{self.id}' " f"id='{self.id}' "
f"loaded={self._loaded}>") f"loaded={self._loaded is not None}>")
def _load_meta(self) -> None: def _load_meta(self) -> None:
try: try:
@ -100,10 +106,11 @@ class ZippedPluginLoader(PluginLoader):
def _get_importer(self, reset_cache: bool = False) -> zipimporter: def _get_importer(self, reset_cache: bool = False) -> zipimporter:
try: try:
importer = zipimporter(self.path) if not self._importer:
self._importer = zipimporter(self.path)
if reset_cache: if reset_cache:
importer.reset_cache() self._importer.reset_cache()
return importer return self._importer
except ZipImportError as e: except ZipImportError as e:
raise MaubotZipImportError("File not found or not a maubot plugin") from e raise MaubotZipImportError("File not found or not a maubot plugin") from e
@ -127,6 +134,8 @@ class ZippedPluginLoader(PluginLoader):
return self._loaded return self._loaded
importer = self._get_importer(reset_cache=reset_cache) importer = self._get_importer(reset_cache=reset_cache)
self._run_preload_checks(importer) self._run_preload_checks(importer)
if reset_cache:
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
for module in self.modules: for module in self.modules:
importer.load_module(module) importer.load_module(module)
main_mod = sys.modules[self.main_module] main_mod = sys.modules[self.main_module]
@ -134,6 +143,7 @@ class ZippedPluginLoader(PluginLoader):
if not issubclass(plugin, Plugin): if not issubclass(plugin, Plugin):
raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin") raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin")
self._loaded = plugin self._loaded = plugin
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
return plugin return plugin
def reload(self) -> Type[PluginClass]: def reload(self) -> Type[PluginClass]:
@ -144,6 +154,8 @@ class ZippedPluginLoader(PluginLoader):
for name, mod in list(sys.modules.items()): for name, mod in list(sys.modules.items()):
if getattr(mod, "__file__", "").startswith(self.path): if getattr(mod, "__file__", "").startswith(self.path):
del sys.modules[name] del sys.modules[name]
self._loaded = None
self.log.debug(f"Unloaded plugin {self.id} at {self.path}")
def destroy(self) -> None: def destroy(self) -> None:
self.unload() self.unload()
@ -155,3 +167,24 @@ class ZippedPluginLoader(PluginLoader):
del self.id_cache[self.id] del self.id_cache[self.id]
except KeyError: except KeyError:
pass pass
self.id = None
self.path = None
self.version = None
self.modules = None
if self._importer:
self._importer.remove_cache()
self._importer = None
self._loaded = None
@classmethod
def load_all(cls, *args: str) -> None:
cls.log.debug("Preloading plugins...")
for directory in args:
for file in os.listdir(directory):
if not file.endswith(".mbp"):
continue
path = os.path.join(directory, file)
try:
ZippedPluginLoader.get(path)
except (MaubotZipImportError, IDConflictError):
cls.log.exception(f"Failed to load plugin at {path}")

View file

@ -13,12 +13,15 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List from typing import Dict, List, Optional
import logging import logging
from mautrix.types import UserID from mautrix.types import UserID
from .db import DBPlugin from .db import DBPlugin
from .client import Client
from .loader import PluginLoader
from .plugin_base import Plugin
log = logging.getLogger("maubot.plugin") log = logging.getLogger("maubot.plugin")
@ -27,10 +30,56 @@ class PluginInstance:
cache: Dict[str, 'PluginInstance'] = {} cache: Dict[str, 'PluginInstance'] = {}
plugin_directories: List[str] = [] plugin_directories: List[str] = []
log: logging.Logger
loader: PluginLoader
client: Client
plugin: Plugin
def __init__(self, db_instance: DBPlugin): def __init__(self, db_instance: DBPlugin):
self.db_instance = db_instance self.db_instance = db_instance
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
self.cache[self.id] = self self.cache[self.id] = self
def load(self) -> None:
try:
self.loader = PluginLoader.find(self.type)
except KeyError:
self.log.error(f"Failed to find loader for type {self.type}")
self.db_instance.enabled = False
return
self.client = Client.get(self.primary_user)
if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}")
self.db_instance.enabled = False
async def start(self) -> None:
self.log.debug(f"Starting...")
cls = self.loader.load()
self.plugin = cls(self.client.client, self.id, self.log)
self.loader.references |= {self}
await self.plugin.start()
async def stop(self) -> None:
self.log.debug("Stopping...")
self.loader.references -= {self}
await self.plugin.stop()
self.plugin = None
@classmethod
def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
) -> Optional['PluginInstance']:
try:
return cls.cache[instance_id]
except KeyError:
db_instance = db_instance or DBPlugin.query.get(instance_id)
if not db_instance:
return None
return PluginInstance(db_instance)
@classmethod
def all(cls) -> List['PluginInstance']:
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
# region Properties # region Properties
@property @property

View file

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from logging import Logger
from abc import ABC from abc import ABC
if TYPE_CHECKING: if TYPE_CHECKING:
@ -22,9 +23,14 @@ if TYPE_CHECKING:
class Plugin(ABC): class Plugin(ABC):
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None: client: 'MaubotMatrixClient'
id: str
log: Logger
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger) -> None:
self.client = client self.client = client
self.id = plugin_instance_id self.id = plugin_instance_id
self.log = log
def set_command_spec(self, spec: 'CommandSpec') -> None: def set_command_spec(self, spec: 'CommandSpec') -> None:
self.client.set_command_spec(self.id, spec) self.client.set_command_spec(self.id, spec)