Change things

This commit is contained in:
Tulir Asokan 2019-03-07 21:35:35 +02:00
parent b3e1f1d4bc
commit d2b145d0bc
4 changed files with 111 additions and 73 deletions

View file

@ -59,8 +59,7 @@ init_zip_loader(config)
db_session = init_db(config) db_session = init_db(config)
clients = init_client_class(db_session, loop) clients = init_client_class(db_session, loop)
management_api = init_mgmt_api(config, loop) management_api = init_mgmt_api(config, loop)
server = MaubotServer(config, loop) server = MaubotServer(management_api, config, loop)
server.app.add_subapp(config["server.base_path"], management_api)
plugins = init_plugin_instance_class(db_session, config, server, loop) plugins = init_plugin_instance_class(db_session, config, server, loop)
for plugin in plugins: for plugin in plugins:

View file

@ -24,7 +24,7 @@ from aiohttp import ClientSession
if TYPE_CHECKING: if TYPE_CHECKING:
from mautrix.util.config import BaseProxyConfig from mautrix.util.config import BaseProxyConfig
from .client import MaubotMatrixClient from .client import MaubotMatrixClient
from .server import PluginWebApp from .plugin_server import PluginWebApp
class Plugin(ABC): class Plugin(ABC):

87
maubot/plugin_server.py Normal file
View file

@ -0,0 +1,87 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Callable, Awaitable
from functools import partial
from aiohttp import web, hdrs
from yarl import URL
Handler = Callable[[web.Request], Awaitable[web.Response]]
Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]]
class PluginWebApp(web.UrlDispatcher):
def __init__(self):
super().__init__()
self._middleware: List[Middleware] = []
def add_middleware(self, middleware: Middleware) -> None:
self._middleware.append(middleware)
def remove_middleware(self, middleware: Middleware) -> None:
self._middleware.remove(middleware)
def clear(self) -> None:
self._resources = []
self._named_resources = {}
self._middleware = []
async def handle(self, request: web.Request) -> web.Response:
match_info = await self.resolve(request)
match_info.freeze()
resp = None
request._match_info = match_info
expect = request.headers.get(hdrs.EXPECT)
if expect:
resp = await match_info.expect_handler(request)
await request.writer.drain()
if resp is None:
handler = match_info.handler
for middleware in self._middleware:
handler = partial(middleware, handler=handler)
resp = await handler(request)
return resp
class PrefixResource(web.Resource):
def __init__(self, prefix, *, name=None):
assert not prefix or prefix.startswith('/'), prefix
assert prefix in ('', '/') or not prefix.endswith('/'), prefix
super().__init__(name=name)
self._prefix = URL.build(path=prefix).raw_path
@property
def canonical(self):
return self._prefix
def get_info(self):
return {'path': self._prefix}
def url_for(self):
return URL.build(path=self._prefix, encoded=True)
def add_prefix(self, prefix):
assert prefix.startswith('/')
assert not prefix.endswith('/')
assert len(prefix) > 1
self._prefix = prefix + self._prefix
def _match(self, path: str) -> dict:
return {} if self.raw_match(path) else None
def raw_match(self, path: str) -> bool:
return path and path.startswith(self._prefix)

View file

@ -13,18 +13,18 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, List, Dict, Callable, Awaitable from typing import Tuple, Dict
from functools import partial
import logging import logging
import asyncio import asyncio
from aiohttp import web, hdrs, URL from aiohttp import web, hdrs
from aiohttp.abc import AbstractAccessLogger from aiohttp.abc import AbstractAccessLogger
import pkg_resources import pkg_resources
from mautrix.api import PathBuilder, Method from mautrix.api import PathBuilder, Method
from .config import Config from .config import Config
from .plugin_server import PrefixResource, PluginWebApp
from .__meta__ import __version__ from .__meta__ import __version__
@ -35,78 +35,19 @@ class AccessLogger(AbstractAccessLogger):
f'in {round(time, 4)}s"') f'in {round(time, 4)}s"')
Handler = Callable[[web.Request], Awaitable[web.Response]]
Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]]
class PluginWebApp(web.UrlDispatcher):
def __init__(self):
super().__init__()
self._middleware: List[Middleware] = []
def add_middleware(self, middleware: Middleware) -> None:
self._middleware.append(middleware)
def remove_middleware(self, middleware: Middleware) -> None:
self._middleware.remove(middleware)
async def handle(self, request: web.Request) -> web.Response:
match_info = await self.resolve(request)
match_info.freeze()
resp = None
request._match_info = match_info
expect = request.headers.get(hdrs.EXPECT)
if expect:
resp = await match_info.expect_handler(request)
await request.writer.drain()
if resp is None:
handler = match_info.handler
for middleware in self._middleware:
handler = partial(middleware, handler=handler)
resp = await handler(request)
return resp
class PrefixResource(web.Resource):
def __init__(self, prefix, *, name=None):
assert not prefix or prefix.startswith('/'), prefix
assert prefix in ('', '/') or not prefix.endswith('/'), prefix
super().__init__(name=name)
self._prefix = URL.build(path=prefix).raw_path
@property
def canonical(self):
return self._prefix
def add_prefix(self, prefix):
assert prefix.startswith('/')
assert not prefix.endswith('/')
assert len(prefix) > 1
self._prefix = prefix + self._prefix
def _match(self, path: str) -> dict:
return {} if self.raw_match(path) else None
def raw_match(self, path: str) -> bool:
return path and path.startswith(self._prefix)
class MaubotServer: class MaubotServer:
log: logging.Logger = logging.getLogger("maubot.server") log: logging.Logger = logging.getLogger("maubot.server")
plugin_routes: Dict[str, PluginWebApp]
def __init__(self, config: Config, loop: asyncio.AbstractEventLoop) -> None: def __init__(self, management_api: web.Application, config: Config,
loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop or asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
self.app = web.Application(loop=self.loop, client_max_size=100 * 1024 * 1024) self.app = web.Application(loop=self.loop, client_max_size=100 * 1024 * 1024)
self.config = config self.config = config
as_path = PathBuilder(config["server.appservice_base_path"]) self.setup_appservice()
self.add_route(Method.PUT, as_path.transactions, self.handle_transaction) self.app.add_subapp(config["server.base_path"], management_api)
self.setup_instance_subapps()
self.plugin_routes: Dict[str, PluginWebApp] = {}
resource = PrefixResource(config["server.plugin_base_path"])
resource.add_route(hdrs.METH_ANY, self.handle_plugin_path)
self.app.router.register_resource(resource)
self.setup_management_ui() self.setup_management_ui()
self.runner = web.AppRunner(self.app, access_log_class=AccessLogger) self.runner = web.AppRunner(self.app, access_log_class=AccessLogger)
@ -114,7 +55,8 @@ class MaubotServer:
async def handle_plugin_path(self, request: web.Request) -> web.Response: async def handle_plugin_path(self, request: web.Request) -> web.Response:
for path, app in self.plugin_routes.items(): for path, app in self.plugin_routes.items():
if request.path.startswith(path): if request.path.startswith(path):
request = request.clone(rel_url=request.path[len(path):]) request = request.clone(
rel_url=request.rel_url.with_path(request.rel_url.path[len(path):]))
return await app.handle(request) return await app.handle(request)
return web.Response(status=404) return web.Response(status=404)
@ -131,10 +73,20 @@ class MaubotServer:
def remove_instance_webapp(self, instance_id: str) -> None: def remove_instance_webapp(self, instance_id: str) -> None:
try: try:
subpath = self.config["server.plugin_base_path"] + instance_id subpath = self.config["server.plugin_base_path"] + instance_id
self.plugin_routes.pop(subpath) self.plugin_routes.pop(subpath).clear()
except KeyError: except KeyError:
return return
def setup_instance_subapps(self) -> None:
self.plugin_routes = {}
resource = PrefixResource(self.config["server.plugin_base_path"].rstrip("/"))
resource.add_route(hdrs.METH_ANY, self.handle_plugin_path)
self.app.router.register_resource(resource)
def setup_appservice(self) -> None:
as_path = PathBuilder(self.config["server.appservice_base_path"])
self.add_route(Method.PUT, as_path.transactions, self.handle_transaction)
def setup_management_ui(self) -> None: def setup_management_ui(self) -> None:
ui_base = self.config["server.ui_base_path"] ui_base = self.config["server.ui_base_path"]
if ui_base == "/": if ui_base == "/":