Try another approach for plugin web apps

This commit is contained in:
Tulir Asokan 2019-03-07 19:57:10 +02:00
parent 3c2d0a9fde
commit b3e1f1d4bc
3 changed files with 79 additions and 20 deletions

View file

@ -35,7 +35,7 @@ from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin
if TYPE_CHECKING:
from .server import MaubotServer
from .server import MaubotServer, PluginWebApp
log = logging.getLogger("maubot.instance")
@ -59,7 +59,7 @@ class PluginInstance:
base_cfg: RecursiveDict[CommentedMap]
inst_db: sql.engine.Engine
inst_db_tables: Dict[str, sql.Table]
inst_webapp: web.Application
inst_webapp: 'PluginWebApp'
inst_webapp_url: str
started: bool

View file

@ -17,7 +17,6 @@ from typing import Type, Optional, TYPE_CHECKING
from abc import ABC
from logging import Logger
from asyncio import AbstractEventLoop
from aiohttp.web import Application
from sqlalchemy.engine.base import Engine
from aiohttp import ClientSession
@ -25,6 +24,7 @@ from aiohttp import ClientSession
if TYPE_CHECKING:
from mautrix.util.config import BaseProxyConfig
from .client import MaubotMatrixClient
from .server import PluginWebApp
class Plugin(ABC):
@ -34,10 +34,12 @@ class Plugin(ABC):
loop: AbstractEventLoop
config: Optional['BaseProxyConfig']
database: Optional[Engine]
webapp: Optional['PluginWebApp']
webapp_url: Optional[str]
def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession,
instance_id: str, log: Logger, config: Optional['BaseProxyConfig'],
database: Optional[Engine], webapp: Optional[Application],
database: Optional[Engine], webapp: Optional['PluginWebApp'],
webapp_url: Optional[str]) -> None:
self.client = client
self.loop = loop

View file

@ -13,11 +13,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Dict
from typing import Tuple, List, Dict, Callable, Awaitable
from functools import partial
import logging
import asyncio
from aiohttp import web
from aiohttp import web, hdrs, URL
from aiohttp.abc import AbstractAccessLogger
import pkg_resources
@ -34,6 +35,62 @@ class AccessLogger(AbstractAccessLogger):
f'in {round(time, 4)}s"')
Handler = Callable[[web.Request], Awaitable[web.Response]]
Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]]
class PluginWebApp(web.UrlDispatcher):
def __init__(self):
super().__init__()
self._middleware: List[Middleware] = []
def add_middleware(self, middleware: Middleware) -> None:
self._middleware.append(middleware)
def remove_middleware(self, middleware: Middleware) -> None:
self._middleware.remove(middleware)
async def handle(self, request: web.Request) -> web.Response:
match_info = await self.resolve(request)
match_info.freeze()
resp = None
request._match_info = match_info
expect = request.headers.get(hdrs.EXPECT)
if expect:
resp = await match_info.expect_handler(request)
await request.writer.drain()
if resp is None:
handler = match_info.handler
for middleware in self._middleware:
handler = partial(middleware, handler=handler)
resp = await handler(request)
return resp
class PrefixResource(web.Resource):
def __init__(self, prefix, *, name=None):
assert not prefix or prefix.startswith('/'), prefix
assert prefix in ('', '/') or not prefix.endswith('/'), prefix
super().__init__(name=name)
self._prefix = URL.build(path=prefix).raw_path
@property
def canonical(self):
return self._prefix
def add_prefix(self, prefix):
assert prefix.startswith('/')
assert not prefix.endswith('/')
assert len(prefix) > 1
self._prefix = prefix + self._prefix
def _match(self, path: str) -> dict:
return {} if self.raw_match(path) else None
def raw_match(self, path: str) -> bool:
return path and path.startswith(self._prefix)
class MaubotServer:
log: logging.Logger = logging.getLogger("maubot.server")
@ -45,38 +102,38 @@ class MaubotServer:
as_path = PathBuilder(config["server.appservice_base_path"])
self.add_route(Method.PUT, as_path.transactions, self.handle_transaction)
self.plugin_apps: Dict[str, web.Application] = {}
self.app.router.add_view(config["server.plugin_base_path"], self.handle_plugin_path)
self.plugin_routes: Dict[str, PluginWebApp] = {}
resource = PrefixResource(config["server.plugin_base_path"])
resource.add_route(hdrs.METH_ANY, self.handle_plugin_path)
self.app.router.register_resource(resource)
self.setup_management_ui()
self.runner = web.AppRunner(self.app, access_log_class=AccessLogger)
async def handle_plugin_path(self, request: web.Request) -> web.Response:
for path, app in self.plugin_apps.items():
for path, app in self.plugin_routes.items():
if request.path.startswith(path):
# TODO there's probably a correct way to do these
request._rel_url.path = request._rel_url.path[len(path):]
return await app._handle(request)
request = request.clone(rel_url=request.path[len(path):])
return await app.handle(request)
return web.Response(status=404)
def get_instance_subapp(self, instance_id: str) -> Tuple[web.Application, str]:
subpath = self.config["server.plugin_base_path"].format(id=instance_id)
def get_instance_subapp(self, instance_id: str) -> Tuple[PluginWebApp, str]:
subpath = self.config["server.plugin_base_path"] + instance_id
url = self.config["server.public_url"] + subpath
try:
return self.plugin_apps[subpath], url
return self.plugin_routes[subpath], url
except KeyError:
app = web.Application(loop=self.loop)
self.plugin_apps[subpath] = app
app = PluginWebApp()
self.plugin_routes[subpath] = app
return app, url
def remove_instance_webapp(self, instance_id: str) -> None:
try:
subapp: web.Application = self.plugin_apps.pop(instance_id)
subpath = self.config["server.plugin_base_path"] + instance_id
self.plugin_routes.pop(subpath)
except KeyError:
return
subapp.shutdown()
subapp.cleanup()
def setup_management_ui(self) -> None:
ui_base = self.config["server.ui_base_path"]