Try another approach for plugin web apps

This commit is contained in:
Tulir Asokan 2019-03-07 19:57:10 +02:00
parent 3c2d0a9fde
commit b3e1f1d4bc
3 changed files with 79 additions and 20 deletions

View file

@ -35,7 +35,7 @@ from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin from .plugin_base import Plugin
if TYPE_CHECKING: if TYPE_CHECKING:
from .server import MaubotServer from .server import MaubotServer, PluginWebApp
log = logging.getLogger("maubot.instance") log = logging.getLogger("maubot.instance")
@ -59,7 +59,7 @@ class PluginInstance:
base_cfg: RecursiveDict[CommentedMap] base_cfg: RecursiveDict[CommentedMap]
inst_db: sql.engine.Engine inst_db: sql.engine.Engine
inst_db_tables: Dict[str, sql.Table] inst_db_tables: Dict[str, sql.Table]
inst_webapp: web.Application inst_webapp: 'PluginWebApp'
inst_webapp_url: str inst_webapp_url: str
started: bool started: bool

View file

@ -17,7 +17,6 @@ from typing import Type, Optional, TYPE_CHECKING
from abc import ABC from abc import ABC
from logging import Logger from logging import Logger
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from aiohttp.web import Application
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from aiohttp import ClientSession from aiohttp import ClientSession
@ -25,6 +24,7 @@ from aiohttp import ClientSession
if TYPE_CHECKING: if TYPE_CHECKING:
from mautrix.util.config import BaseProxyConfig from mautrix.util.config import BaseProxyConfig
from .client import MaubotMatrixClient from .client import MaubotMatrixClient
from .server import PluginWebApp
class Plugin(ABC): class Plugin(ABC):
@ -34,10 +34,12 @@ class Plugin(ABC):
loop: AbstractEventLoop loop: AbstractEventLoop
config: Optional['BaseProxyConfig'] config: Optional['BaseProxyConfig']
database: Optional[Engine] database: Optional[Engine]
webapp: Optional['PluginWebApp']
webapp_url: Optional[str]
def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession, def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession,
instance_id: str, log: Logger, config: Optional['BaseProxyConfig'], instance_id: str, log: Logger, config: Optional['BaseProxyConfig'],
database: Optional[Engine], webapp: Optional[Application], database: Optional[Engine], webapp: Optional['PluginWebApp'],
webapp_url: Optional[str]) -> None: webapp_url: Optional[str]) -> None:
self.client = client self.client = client
self.loop = loop self.loop = loop

View file

@ -13,11 +13,12 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Dict from typing import Tuple, List, Dict, Callable, Awaitable
from functools import partial
import logging import logging
import asyncio import asyncio
from aiohttp import web from aiohttp import web, hdrs, URL
from aiohttp.abc import AbstractAccessLogger from aiohttp.abc import AbstractAccessLogger
import pkg_resources import pkg_resources
@ -34,6 +35,62 @@ class AccessLogger(AbstractAccessLogger):
f'in {round(time, 4)}s"') f'in {round(time, 4)}s"')
Handler = Callable[[web.Request], Awaitable[web.Response]]
Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]]
class PluginWebApp(web.UrlDispatcher):
def __init__(self):
super().__init__()
self._middleware: List[Middleware] = []
def add_middleware(self, middleware: Middleware) -> None:
self._middleware.append(middleware)
def remove_middleware(self, middleware: Middleware) -> None:
self._middleware.remove(middleware)
async def handle(self, request: web.Request) -> web.Response:
match_info = await self.resolve(request)
match_info.freeze()
resp = None
request._match_info = match_info
expect = request.headers.get(hdrs.EXPECT)
if expect:
resp = await match_info.expect_handler(request)
await request.writer.drain()
if resp is None:
handler = match_info.handler
for middleware in self._middleware:
handler = partial(middleware, handler=handler)
resp = await handler(request)
return resp
class PrefixResource(web.Resource):
def __init__(self, prefix, *, name=None):
assert not prefix or prefix.startswith('/'), prefix
assert prefix in ('', '/') or not prefix.endswith('/'), prefix
super().__init__(name=name)
self._prefix = URL.build(path=prefix).raw_path
@property
def canonical(self):
return self._prefix
def add_prefix(self, prefix):
assert prefix.startswith('/')
assert not prefix.endswith('/')
assert len(prefix) > 1
self._prefix = prefix + self._prefix
def _match(self, path: str) -> dict:
return {} if self.raw_match(path) else None
def raw_match(self, path: str) -> bool:
return path and path.startswith(self._prefix)
class MaubotServer: class MaubotServer:
log: logging.Logger = logging.getLogger("maubot.server") log: logging.Logger = logging.getLogger("maubot.server")
@ -45,38 +102,38 @@ class MaubotServer:
as_path = PathBuilder(config["server.appservice_base_path"]) as_path = PathBuilder(config["server.appservice_base_path"])
self.add_route(Method.PUT, as_path.transactions, self.handle_transaction) self.add_route(Method.PUT, as_path.transactions, self.handle_transaction)
self.plugin_apps: Dict[str, web.Application] = {} self.plugin_routes: Dict[str, PluginWebApp] = {}
self.app.router.add_view(config["server.plugin_base_path"], self.handle_plugin_path) resource = PrefixResource(config["server.plugin_base_path"])
resource.add_route(hdrs.METH_ANY, self.handle_plugin_path)
self.app.router.register_resource(resource)
self.setup_management_ui() self.setup_management_ui()
self.runner = web.AppRunner(self.app, access_log_class=AccessLogger) self.runner = web.AppRunner(self.app, access_log_class=AccessLogger)
async def handle_plugin_path(self, request: web.Request) -> web.Response: async def handle_plugin_path(self, request: web.Request) -> web.Response:
for path, app in self.plugin_apps.items(): for path, app in self.plugin_routes.items():
if request.path.startswith(path): if request.path.startswith(path):
# TODO there's probably a correct way to do these request = request.clone(rel_url=request.path[len(path):])
request._rel_url.path = request._rel_url.path[len(path):] return await app.handle(request)
return await app._handle(request)
return web.Response(status=404) return web.Response(status=404)
def get_instance_subapp(self, instance_id: str) -> Tuple[web.Application, str]: def get_instance_subapp(self, instance_id: str) -> Tuple[PluginWebApp, str]:
subpath = self.config["server.plugin_base_path"].format(id=instance_id) subpath = self.config["server.plugin_base_path"] + instance_id
url = self.config["server.public_url"] + subpath url = self.config["server.public_url"] + subpath
try: try:
return self.plugin_apps[subpath], url return self.plugin_routes[subpath], url
except KeyError: except KeyError:
app = web.Application(loop=self.loop) app = PluginWebApp()
self.plugin_apps[subpath] = app self.plugin_routes[subpath] = app
return app, url return app, url
def remove_instance_webapp(self, instance_id: str) -> None: def remove_instance_webapp(self, instance_id: str) -> None:
try: try:
subapp: web.Application = self.plugin_apps.pop(instance_id) subpath = self.config["server.plugin_base_path"] + instance_id
self.plugin_routes.pop(subpath)
except KeyError: except KeyError:
return return
subapp.shutdown()
subapp.cleanup()
def setup_management_ui(self) -> None: def setup_management_ui(self) -> None:
ui_base = self.config["server.ui_base_path"] ui_base = self.config["server.ui_base_path"]