Make it work
This commit is contained in:
parent
dce2771588
commit
1d8de8b5f2
7 changed files with 139 additions and 32 deletions
|
@ -20,13 +20,13 @@ import argparse
|
|||
import asyncio
|
||||
import copy
|
||||
import sys
|
||||
import os
|
||||
|
||||
from .config import Config
|
||||
from .db import Base, init as init_db
|
||||
from .server import MaubotServer
|
||||
from .client import Client, init as init_client
|
||||
from .loader import ZippedPluginLoader, MaubotZipImportError
|
||||
from .loader import ZippedPluginLoader
|
||||
from .plugin import PluginInstance
|
||||
from .__meta__ import __version__
|
||||
|
||||
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
|
||||
|
@ -57,27 +57,22 @@ loop = asyncio.get_event_loop()
|
|||
init_db(db_session)
|
||||
init_client(loop)
|
||||
server = MaubotServer(config, loop)
|
||||
ZippedPluginLoader.load_all(*config["plugin_directories"])
|
||||
plugins = PluginInstance.all()
|
||||
|
||||
loader_log = logging.getLogger("maubot.loader.zip")
|
||||
loader_log.debug("Preloading plugins...")
|
||||
for directory in config["plugin_directories"]:
|
||||
for file in os.listdir(directory):
|
||||
if not file.endswith(".mbp"):
|
||||
continue
|
||||
path = os.path.join(directory, file)
|
||||
try:
|
||||
loader = ZippedPluginLoader.get(path)
|
||||
loader_log.debug(f"Preloaded plugin {loader.id} from {loader.path}.")
|
||||
except MaubotZipImportError:
|
||||
loader_log.exception(f"Failed to load plugin at {path}.")
|
||||
for plugin in plugins:
|
||||
plugin.load()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(server.start())
|
||||
loop.run_until_complete(asyncio.gather(
|
||||
server.start(),
|
||||
*[plugin.start() for plugin in plugins]))
|
||||
log.debug("Startup actions complete, running forever.")
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
log.debug("Keyboard interrupt received, stopping...")
|
||||
for client in Client.cache.values():
|
||||
client.stop()
|
||||
db_session.commit()
|
||||
loop.run_until_complete(server.stop())
|
||||
sys.exit(0)
|
||||
|
|
|
@ -37,15 +37,21 @@ class Client:
|
|||
def __init__(self, db_instance: DBClient) -> None:
|
||||
self.db_instance = db_instance
|
||||
self.cache[self.id] = self
|
||||
self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance,
|
||||
mxid=self.id, base_url=self.homeserver,
|
||||
self.log = log.getChild(self.id)
|
||||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||
token=self.access_token, client_session=self.http_client,
|
||||
log=log.getChild(self.id))
|
||||
log=self.log, loop=self.loop, store=self.db_instance)
|
||||
if self.autojoin:
|
||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||
|
||||
def start(self) -> None:
|
||||
asyncio.ensure_future(self.client.start(), loop=self.loop)
|
||||
asyncio.ensure_future(self._start(), loop=self.loop)
|
||||
|
||||
async def _start(self) -> None:
|
||||
try:
|
||||
await self.client.start()
|
||||
except Exception:
|
||||
self.log.exception("Fail")
|
||||
|
||||
def stop(self) -> None:
|
||||
self.client.stop()
|
||||
|
@ -64,6 +70,10 @@ class Client:
|
|||
def all(cls) -> List['Client']:
|
||||
return [cls.get(user.id, user) for user in DBClient.query.all()]
|
||||
|
||||
async def _handle_invite(self, evt: StateEvent) -> None:
|
||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
await self.client.join_room_by_id(evt.room_id)
|
||||
|
||||
# region Properties
|
||||
|
||||
@property
|
||||
|
@ -72,7 +82,7 @@ class Client:
|
|||
|
||||
@property
|
||||
def homeserver(self) -> str:
|
||||
return self.db_instance.id
|
||||
return self.db_instance.homeserver
|
||||
|
||||
@property
|
||||
def access_token(self) -> str:
|
||||
|
@ -139,12 +149,9 @@ class Client:
|
|||
|
||||
# endregion
|
||||
|
||||
async def _handle_invite(self, evt: StateEvent) -> None:
|
||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
await self.client.join_room_by_id(evt.room_id)
|
||||
|
||||
|
||||
def init(loop: asyncio.AbstractEventLoop) -> None:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
for client in Client.all():
|
||||
client.start()
|
||||
|
|
|
@ -118,6 +118,12 @@ class zipimporter:
|
|||
self._files = _read_directory(self.archive)
|
||||
_zip_directory_cache[self.archive] = self._files
|
||||
|
||||
def remove_cache(self):
|
||||
try:
|
||||
del _zip_directory_cache[self.archive]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Check whether we can satisfy the import of the module named by
|
||||
# 'fullname', or whether it could be a portion of a namespace
|
||||
# package. Return self if we can load it, a string containing the
|
||||
|
|
|
@ -13,11 +13,14 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import TypeVar, Type, Dict
|
||||
from typing import TypeVar, Type, Dict, Set, TYPE_CHECKING
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..plugin_base import Plugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..plugin import PluginInstance
|
||||
|
||||
PluginClass = TypeVar("PluginClass", bound=Plugin)
|
||||
|
||||
|
||||
|
@ -28,9 +31,17 @@ class IDConflictError(Exception):
|
|||
class PluginLoader(ABC):
|
||||
id_cache: Dict[str, 'PluginLoader'] = {}
|
||||
|
||||
references: Set['PluginInstance']
|
||||
id: str
|
||||
version: str
|
||||
|
||||
def __init__(self):
|
||||
self.references = set()
|
||||
|
||||
@classmethod
|
||||
def find(cls, plugin_id: str) -> 'PluginLoader':
|
||||
return cls.id_cache[plugin_id]
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def source(self) -> str:
|
||||
|
|
|
@ -15,8 +15,10 @@
|
|||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Type
|
||||
from zipfile import ZipFile, BadZipFile
|
||||
import sys
|
||||
import configparser
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
|
||||
from ..lib.zipimport import zipimporter, ZipImportError
|
||||
from ..plugin_base import Plugin
|
||||
|
@ -29,6 +31,7 @@ class MaubotZipImportError(Exception):
|
|||
|
||||
class ZippedPluginLoader(PluginLoader):
|
||||
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
|
||||
log = logging.getLogger("maubot.loader.zip")
|
||||
|
||||
path: str
|
||||
id: str
|
||||
|
@ -40,9 +43,11 @@ class ZippedPluginLoader(PluginLoader):
|
|||
_importer: zipimporter
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.id = None
|
||||
self._loaded = None
|
||||
self._importer = None
|
||||
self._load_meta()
|
||||
self._run_preload_checks(self._get_importer())
|
||||
try:
|
||||
|
@ -52,6 +57,7 @@ class ZippedPluginLoader(PluginLoader):
|
|||
pass
|
||||
self.path_cache[self.path] = self
|
||||
self.id_cache[self.id] = self
|
||||
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
|
||||
|
||||
@classmethod
|
||||
def get(cls, path: str) -> 'ZippedPluginLoader':
|
||||
|
@ -68,7 +74,7 @@ class ZippedPluginLoader(PluginLoader):
|
|||
return ("<ZippedPlugin "
|
||||
f"path='{self.path}' "
|
||||
f"id='{self.id}' "
|
||||
f"loaded={self._loaded}>")
|
||||
f"loaded={self._loaded is not None}>")
|
||||
|
||||
def _load_meta(self) -> None:
|
||||
try:
|
||||
|
@ -100,10 +106,11 @@ class ZippedPluginLoader(PluginLoader):
|
|||
|
||||
def _get_importer(self, reset_cache: bool = False) -> zipimporter:
|
||||
try:
|
||||
importer = zipimporter(self.path)
|
||||
if not self._importer:
|
||||
self._importer = zipimporter(self.path)
|
||||
if reset_cache:
|
||||
importer.reset_cache()
|
||||
return importer
|
||||
self._importer.reset_cache()
|
||||
return self._importer
|
||||
except ZipImportError as e:
|
||||
raise MaubotZipImportError("File not found or not a maubot plugin") from e
|
||||
|
||||
|
@ -127,6 +134,8 @@ class ZippedPluginLoader(PluginLoader):
|
|||
return self._loaded
|
||||
importer = self._get_importer(reset_cache=reset_cache)
|
||||
self._run_preload_checks(importer)
|
||||
if reset_cache:
|
||||
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
|
||||
for module in self.modules:
|
||||
importer.load_module(module)
|
||||
main_mod = sys.modules[self.main_module]
|
||||
|
@ -134,6 +143,7 @@ class ZippedPluginLoader(PluginLoader):
|
|||
if not issubclass(plugin, Plugin):
|
||||
raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin")
|
||||
self._loaded = plugin
|
||||
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
|
||||
return plugin
|
||||
|
||||
def reload(self) -> Type[PluginClass]:
|
||||
|
@ -144,6 +154,8 @@ class ZippedPluginLoader(PluginLoader):
|
|||
for name, mod in list(sys.modules.items()):
|
||||
if getattr(mod, "__file__", "").startswith(self.path):
|
||||
del sys.modules[name]
|
||||
self._loaded = None
|
||||
self.log.debug(f"Unloaded plugin {self.id} at {self.path}")
|
||||
|
||||
def destroy(self) -> None:
|
||||
self.unload()
|
||||
|
@ -155,3 +167,24 @@ class ZippedPluginLoader(PluginLoader):
|
|||
del self.id_cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.id = None
|
||||
self.path = None
|
||||
self.version = None
|
||||
self.modules = None
|
||||
if self._importer:
|
||||
self._importer.remove_cache()
|
||||
self._importer = None
|
||||
self._loaded = None
|
||||
|
||||
@classmethod
|
||||
def load_all(cls, *args: str) -> None:
|
||||
cls.log.debug("Preloading plugins...")
|
||||
for directory in args:
|
||||
for file in os.listdir(directory):
|
||||
if not file.endswith(".mbp"):
|
||||
continue
|
||||
path = os.path.join(directory, file)
|
||||
try:
|
||||
ZippedPluginLoader.get(path)
|
||||
except (MaubotZipImportError, IDConflictError):
|
||||
cls.log.exception(f"Failed to load plugin at {path}")
|
||||
|
|
|
@ -13,12 +13,15 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
import logging
|
||||
|
||||
from mautrix.types import UserID
|
||||
|
||||
from .db import DBPlugin
|
||||
from .client import Client
|
||||
from .loader import PluginLoader
|
||||
from .plugin_base import Plugin
|
||||
|
||||
log = logging.getLogger("maubot.plugin")
|
||||
|
||||
|
@ -27,10 +30,56 @@ class PluginInstance:
|
|||
cache: Dict[str, 'PluginInstance'] = {}
|
||||
plugin_directories: List[str] = []
|
||||
|
||||
log: logging.Logger
|
||||
loader: PluginLoader
|
||||
client: Client
|
||||
plugin: Plugin
|
||||
|
||||
def __init__(self, db_instance: DBPlugin):
|
||||
self.db_instance = db_instance
|
||||
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
|
||||
self.cache[self.id] = self
|
||||
|
||||
def load(self) -> None:
|
||||
try:
|
||||
self.loader = PluginLoader.find(self.type)
|
||||
except KeyError:
|
||||
self.log.error(f"Failed to find loader for type {self.type}")
|
||||
self.db_instance.enabled = False
|
||||
return
|
||||
self.client = Client.get(self.primary_user)
|
||||
if not self.client:
|
||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||
self.db_instance.enabled = False
|
||||
|
||||
async def start(self) -> None:
|
||||
self.log.debug(f"Starting...")
|
||||
cls = self.loader.load()
|
||||
self.plugin = cls(self.client.client, self.id, self.log)
|
||||
self.loader.references |= {self}
|
||||
await self.plugin.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
self.log.debug("Stopping...")
|
||||
self.loader.references -= {self}
|
||||
await self.plugin.stop()
|
||||
self.plugin = None
|
||||
|
||||
@classmethod
|
||||
def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
|
||||
) -> Optional['PluginInstance']:
|
||||
try:
|
||||
return cls.cache[instance_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBPlugin.query.get(instance_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return PluginInstance(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> List['PluginInstance']:
|
||||
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
|
||||
|
||||
# region Properties
|
||||
|
||||
@property
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import TYPE_CHECKING
|
||||
from logging import Logger
|
||||
from abc import ABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -22,9 +23,14 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class Plugin(ABC):
|
||||
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str) -> None:
|
||||
client: 'MaubotMatrixClient'
|
||||
id: str
|
||||
log: Logger
|
||||
|
||||
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger) -> None:
|
||||
self.client = client
|
||||
self.id = plugin_instance_id
|
||||
self.log = log
|
||||
|
||||
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
||||
self.client.set_command_spec(self.id, spec)
|
||||
|
|
Loading…
Reference in a new issue