From 767885cec7d549731e7205d7fdf8d91dabdbde23 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Nov 2018 11:58:58 +0200 Subject: [PATCH] Pass asyncio event loop and http session to plugin instances --- maubot/client.py | 4 ++-- maubot/instance.py | 9 ++++++--- maubot/plugin_base.py | 10 ++++++++-- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/maubot/client.py b/maubot/client.py index 9f4fcc8..dee1d8c 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -32,8 +32,8 @@ log = logging.getLogger("maubot.client") class Client: - log: logging.Logger - loop: asyncio.AbstractEventLoop + log: logging.Logger = None + loop: asyncio.AbstractEventLoop = None cache: Dict[UserID, 'Client'] = {} http_client: ClientSession = None diff --git a/maubot/instance.py b/maubot/instance.py index 83bb85c..924bf5a 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -17,6 +17,7 @@ from typing import Dict, List, Optional from sqlalchemy.orm import Session from ruamel.yaml.comments import CommentedMap from ruamel.yaml import YAML +from asyncio import AbstractEventLoop import logging import io @@ -38,6 +39,7 @@ yaml.indent(4) class PluginInstance: db: Session = None mb_config: Config = None + loop: AbstractEventLoop = None cache: Dict[str, 'PluginInstance'] = {} plugin_directories: List[str] = [] @@ -109,8 +111,8 @@ class PluginInstance: except (FileNotFoundError, KeyError): base_file = None self.config = config_class(self.load_config, lambda: base_file, self.save_config) - self.plugin = cls(self.client.client, self.id, self.log, self.config, - self.mb_config["plugin_directories.db"]) + self.plugin = cls(self.client.client, self.loop, self.client.http_client, self.id, + self.log, self.config, self.mb_config["plugin_directories.db"]) try: await self.plugin.start() except Exception: @@ -178,6 +180,7 @@ class PluginInstance: # endregion -def init(db: Session, config: Config): +def init(db: Session, config: Config, loop: AbstractEventLoop): PluginInstance.db = db PluginInstance.mb_config = config + PluginInstance.loop = loop diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py index 18620c6..9b394f1 100644 --- a/maubot/plugin_base.py +++ b/maubot/plugin_base.py @@ -16,6 +16,8 @@ from typing import Type, Optional, TYPE_CHECKING from logging import Logger from abc import ABC, abstractmethod +from asyncio import AbstractEventLoop +from aiohttp import ClientSession import os.path from sqlalchemy.engine.base import Engine @@ -34,11 +36,15 @@ class Plugin(ABC): client: 'MaubotMatrixClient' id: str log: Logger + loop: AbstractEventLoop config: Optional['BaseProxyConfig'] - def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger, - config: Optional['BaseProxyConfig'], db_base_path: str) -> None: + def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession, + plugin_instance_id: str, log: Logger, config: Optional['BaseProxyConfig'], + db_base_path: str) -> None: self.client = client + self.loop = loop + self.http = http self.id = plugin_instance_id self.log = log self.config = config