Make it work

This commit is contained in:
Tulir Asokan 2018-10-17 01:30:08 +03:00
parent dce2771588
commit 1d8de8b5f2
7 changed files with 139 additions and 32 deletions

View file

@ -20,13 +20,13 @@ import argparse
import asyncio
import copy
import sys
import os
from .config import Config
from .db import Base, init as init_db
from .server import MaubotServer
from .client import Client, init as init_client
from .loader import ZippedPluginLoader, MaubotZipImportError
from .loader import ZippedPluginLoader
from .plugin import PluginInstance
from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
@ -57,27 +57,22 @@ loop = asyncio.get_event_loop()
init_db(db_session)
init_client(loop)
server = MaubotServer(config, loop)
ZippedPluginLoader.load_all(*config["plugin_directories"])
plugins = PluginInstance.all()
loader_log = logging.getLogger("maubot.loader.zip")
loader_log.debug("Preloading plugins...")
for directory in config["plugin_directories"]:
for file in os.listdir(directory):
if not file.endswith(".mbp"):
continue
path = os.path.join(directory, file)
try:
loader = ZippedPluginLoader.get(path)
loader_log.debug(f"Preloaded plugin {loader.id} from {loader.path}.")
except MaubotZipImportError:
loader_log.exception(f"Failed to load plugin at {path}.")
for plugin in plugins:
plugin.load()
try:
loop.run_until_complete(server.start())
loop.run_until_complete(asyncio.gather(
server.start(),
*[plugin.start() for plugin in plugins]))
log.debug("Startup actions complete, running forever.")
loop.run_forever()
except KeyboardInterrupt:
log.debug("Keyboard interrupt received, stopping...")
for client in Client.cache.values():
client.stop()
db_session.commit()
loop.run_until_complete(server.stop())
sys.exit(0)

View file

@ -37,15 +37,21 @@ class Client:
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
self.cache[self.id] = self
self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance,
mxid=self.id, base_url=self.homeserver,
self.log = log.getChild(self.id)
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=log.getChild(self.id))
log=self.log, loop=self.loop, store=self.db_instance)
if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
def start(self) -> None:
asyncio.ensure_future(self.client.start(), loop=self.loop)
asyncio.ensure_future(self._start(), loop=self.loop)
async def _start(self) -> None:
try:
await self.client.start()
except Exception:
self.log.exception("Fail")
def stop(self) -> None:
self.client.stop()
@ -64,6 +70,10 @@ class Client:
def all(cls) -> List['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()]
async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)
# region Properties
@property
@ -72,7 +82,7 @@ class Client:
@property
def homeserver(self) -> str:
return self.db_instance.id
return self.db_instance.homeserver
@property
def access_token(self) -> str:
@ -139,12 +149,9 @@ class Client:
# endregion
async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)
def init(loop: asyncio.AbstractEventLoop) -> None:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
for client in Client.all():
client.start()

View file

@ -118,6 +118,12 @@ class zipimporter:
self._files = _read_directory(self.archive)
_zip_directory_cache[self.archive] = self._files
def remove_cache(self):
try:
del _zip_directory_cache[self.archive]
except KeyError:
pass
# Check whether we can satisfy the import of the module named by
# 'fullname', or whether it could be a portion of a namespace
# package. Return self if we can load it, a string containing the

View file

@ -13,11 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TypeVar, Type, Dict
from typing import TypeVar, Type, Dict, Set, TYPE_CHECKING
from abc import ABC, abstractmethod
from ..plugin_base import Plugin
if TYPE_CHECKING:
from ..plugin import PluginInstance
PluginClass = TypeVar("PluginClass", bound=Plugin)
@ -28,9 +31,17 @@ class IDConflictError(Exception):
class PluginLoader(ABC):
id_cache: Dict[str, 'PluginLoader'] = {}
references: Set['PluginInstance']
id: str
version: str
def __init__(self):
self.references = set()
@classmethod
def find(cls, plugin_id: str) -> 'PluginLoader':
return cls.id_cache[plugin_id]
@property
@abstractmethod
def source(self) -> str:

View file

@ -15,8 +15,10 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Type
from zipfile import ZipFile, BadZipFile
import sys
import configparser
import logging
import sys
import os
from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin
@ -29,6 +31,7 @@ class MaubotZipImportError(Exception):
class ZippedPluginLoader(PluginLoader):
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
log = logging.getLogger("maubot.loader.zip")
path: str
id: str
@ -40,9 +43,11 @@ class ZippedPluginLoader(PluginLoader):
_importer: zipimporter
def __init__(self, path: str) -> None:
super().__init__()
self.path = path
self.id = None
self._loaded = None
self._importer = None
self._load_meta()
self._run_preload_checks(self._get_importer())
try:
@ -52,6 +57,7 @@ class ZippedPluginLoader(PluginLoader):
pass
self.path_cache[self.path] = self
self.id_cache[self.id] = self
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
@classmethod
def get(cls, path: str) -> 'ZippedPluginLoader':
@ -68,7 +74,7 @@ class ZippedPluginLoader(PluginLoader):
return ("<ZippedPlugin "
f"path='{self.path}' "
f"id='{self.id}' "
f"loaded={self._loaded}>")
f"loaded={self._loaded is not None}>")
def _load_meta(self) -> None:
try:
@ -100,10 +106,11 @@ class ZippedPluginLoader(PluginLoader):
def _get_importer(self, reset_cache: bool = False) -> zipimporter:
try:
importer = zipimporter(self.path)
if not self._importer:
self._importer = zipimporter(self.path)
if reset_cache:
importer.reset_cache()
return importer
self._importer.reset_cache()
return self._importer
except ZipImportError as e:
raise MaubotZipImportError("File not found or not a maubot plugin") from e
@ -127,6 +134,8 @@ class ZippedPluginLoader(PluginLoader):
return self._loaded
importer = self._get_importer(reset_cache=reset_cache)
self._run_preload_checks(importer)
if reset_cache:
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
for module in self.modules:
importer.load_module(module)
main_mod = sys.modules[self.main_module]
@ -134,6 +143,7 @@ class ZippedPluginLoader(PluginLoader):
if not issubclass(plugin, Plugin):
raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin")
self._loaded = plugin
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
return plugin
def reload(self) -> Type[PluginClass]:
@ -144,6 +154,8 @@ class ZippedPluginLoader(PluginLoader):
for name, mod in list(sys.modules.items()):
if getattr(mod, "__file__", "").startswith(self.path):
del sys.modules[name]
self._loaded = None
self.log.debug(f"Unloaded plugin {self.id} at {self.path}")
def destroy(self) -> None:
self.unload()
@ -155,3 +167,24 @@ class ZippedPluginLoader(PluginLoader):
del self.id_cache[self.id]
except KeyError:
pass
self.id = None
self.path = None
self.version = None
self.modules = None
if self._importer:
self._importer.remove_cache()
self._importer = None
self._loaded = None
@classmethod
def load_all(cls, *args: str) -> None:
cls.log.debug("Preloading plugins...")
for directory in args:
for file in os.listdir(directory):
if not file.endswith(".mbp"):
continue
path = os.path.join(directory, file)
try:
ZippedPluginLoader.get(path)
except (MaubotZipImportError, IDConflictError):
cls.log.exception(f"Failed to load plugin at {path}")

View file

@ -13,12 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List
from typing import Dict, List, Optional
import logging
from mautrix.types import UserID
from .db import DBPlugin
from .client import Client
from .loader import PluginLoader
from .plugin_base import Plugin
log = logging.getLogger("maubot.plugin")
@ -27,10 +30,56 @@ class PluginInstance:
cache: Dict[str, 'PluginInstance'] = {}
plugin_directories: List[str] = []
log: logging.Logger
loader: PluginLoader
client: Client
plugin: Plugin
def __init__(self, db_instance: DBPlugin):
self.db_instance = db_instance
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
self.cache[self.id] = self
def load(self) -> None:
try:
self.loader = PluginLoader.find(self.type)
except KeyError:
self.log.error(f"Failed to find loader for type {self.type}")
self.db_instance.enabled = False
return
self.client = Client.get(self.primary_user)
if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}")
self.db_instance.enabled = False
async def start(self) -> None:
self.log.debug(f"Starting...")
cls = self.loader.load()
self.plugin = cls(self.client.client, self.id, self.log)
self.loader.references |= {self}
await self.plugin.start()
async def stop(self) -> None:
self.log.debug("Stopping...")
self.loader.references -= {self}
await self.plugin.stop()
self.plugin = None
@classmethod
def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
) -> Optional['PluginInstance']:
try:
return cls.cache[instance_id]
except KeyError:
db_instance = db_instance or DBPlugin.query.get(instance_id)
if not db_instance:
return None
return PluginInstance(db_instance)
@classmethod
def all(cls) -> List['PluginInstance']:
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
# region Properties
@property

View file

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TYPE_CHECKING
from logging import Logger
from abc import ABC
if TYPE_CHECKING:
@ -22,9 +23,14 @@ if TYPE_CHECKING:
class Plugin(ABC):
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None:
client: 'MaubotMatrixClient'
id: str
log: Logger
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger) -> None:
self.client = client
self.id = plugin_instance_id
self.log = log
def set_command_spec(self, spec: 'CommandSpec') -> None:
self.client.set_command_spec(self.id, spec)