1
0
Fork 0
forked from vbatts/maubot
maubot/maubot/management/api/client_auth.py
2022-03-25 19:52:14 +02:00

272 lines
8.8 KiB
Python

# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, NamedTuple, Optional, Tuple
from http import HTTPStatus
from json import JSONDecodeError
import asyncio
import hashlib
import hmac
import random
import secrets
import string

from aiohttp import web
from yarl import URL

from mautrix.api import Method, Path, SynapseAdminPath
from mautrix.client import ClientAPI
from mautrix.errors import MatrixRequestError
from mautrix.types import LoginResponse, LoginType

from .base import get_config, get_loop, routes
from .client import _create_client, _create_or_update_client
from .responses import resp
def known_homeservers() -> Dict[str, Dict[str, str]]:
    """Return the ``homeservers`` section of the maubot config.

    Each value is the per-server settings mapping read elsewhere in this
    module (``url`` and, optionally, ``secret``).
    """
    config = get_config()
    return config["homeservers"]
@routes.get("/client/auth/servers")
async def get_known_servers(_: web.Request) -> web.Response:
    """List the configured homeservers as a JSON ``{name: base URL}`` object."""
    urls = {}
    for name, settings in known_homeservers().items():
        urls[name] = settings["url"]
    return web.json_response(urls)
class AuthRequestInfo(NamedTuple):
    """Parsed parameters of a client auth (login or register) request."""

    # Key of the homeserver in the ``homeservers`` config section.
    server_name: str
    # API client pointed at the homeserver's configured base URL.
    client: ClientAPI
    # Shared registration secret from the server config; None when not configured.
    secret: Optional[str]
    # Credentials; both are None for SSO requests that omit them.
    username: Optional[str]
    password: Optional[str]
    # Value sent as ``user_type`` in shared-secret registration (defaults to "bot").
    user_type: str
    # Initial display name for the new device (defaults to "Maubot").
    device_name: str
    # Whether to create/update a maubot client from the resulting session.
    update_client: bool
    # Whether the request asked for SSO login (``?sso=...``).
    sso: bool
# Query-string values treated as "true" for boolean flags like ?sso=...
truthy_strings = ("1", "true", "yes")


async def read_client_auth_request(
    request: web.Request,
) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]:
    """Parse and validate a client auth request.

    Looks up the ``{server}`` path parameter among the configured homeservers,
    reads the JSON body and the ``sso``/``update_client`` query flags, and
    builds an :class:`AuthRequestInfo`.

    Returns:
        ``(info, None)`` on success, or ``(None, error_response)`` when the
        server is unknown or has no ``url``, the body is not a JSON object,
        or username/password are missing for a non-SSO request.
    """
    server_name = request.match_info.get("server", None)
    server = known_homeservers().get(server_name, None)
    if not server:
        return None, resp.server_not_found
    try:
        body = await request.json()
    except (JSONDecodeError, UnicodeDecodeError):
        # UnicodeDecodeError: the raw body could not be decoded to text at all.
        return None, resp.body_not_json
    # A valid JSON array/string/number would otherwise crash the key lookups
    # below with TypeError/AttributeError and surface as a 500.
    if not isinstance(body, dict):
        return None, resp.body_not_json
    sso = request.query.get("sso", "").lower() in truthy_strings
    try:
        username = body["username"]
        password = body["password"]
    except KeyError:
        # Credentials are optional when logging in via SSO.
        if not sso:
            return None, resp.username_or_password_missing
        username = password = None
    try:
        base_url = server["url"]
    except KeyError:
        return None, resp.invalid_server
    return (
        AuthRequestInfo(
            server_name=server_name,
            client=ClientAPI(base_url=base_url, loop=get_loop()),
            secret=server.get("secret"),
            username=username,
            password=password,
            user_type=body.get("user_type", "bot"),
            device_name=body.get("device_name", "Maubot"),
            update_client=request.query.get("update_client", "").lower() in truthy_strings,
            sso=sso,
        ),
        None,
    )
def generate_mac(
    secret: str,
    nonce: str,
    username: str,
    password: str,
    admin: bool = False,
    user_type: Optional[str] = None,
) -> str:
    """Compute the MAC for Synapse's shared-secret registration request.

    The MAC is HMAC-SHA1 keyed with ``secret`` over the NUL-separated,
    UTF-8 encoded fields ``nonce``, ``username``, ``password``,
    ``admin``/``notadmin`` and, only when given, ``user_type``.

    Returns:
        The MAC as a lowercase hex string.
    """
    mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
    mac.update(nonce.encode("utf-8"))
    mac.update(b"\x00")
    mac.update(username.encode("utf-8"))
    mac.update(b"\x00")
    mac.update(password.encode("utf-8"))
    mac.update(b"\x00")
    mac.update(b"admin" if admin else b"notadmin")
    if user_type is not None:
        # user_type is appended as an optional trailing field.
        mac.update(b"\x00")
        mac.update(user_type.encode("utf-8"))
    return mac.hexdigest()
@routes.post("/client/auth/{server}/register")
async def register(request: web.Request) -> web.Response:
    """Register a new account through Synapse's shared-secret admin API.

    Requires the server to have a ``secret`` configured and rejects SSO
    requests. On success returns the login response as JSON or, with
    ``?update_client``, creates a maubot client from the new session.
    Homeserver errors (on either the nonce fetch or the registration) are
    reported as a 500 JSON body carrying the upstream ``http_status``.
    """
    req, err = await read_client_auth_request(request)
    if err is not None:
        return err
    if req.sso:
        return resp.registration_no_sso
    elif not req.secret:
        return resp.registration_secret_not_found
    path = SynapseAdminPath.v1.register
    try:
        # Fetch a fresh nonce; it is included in the MAC computed below.
        nonce_res = await req.client.api.request(Method.GET, path)
        content = {
            "nonce": nonce_res["nonce"],
            "username": req.username,
            "password": req.password,
            "admin": False,
            "user_type": req.user_type,
        }
        content["mac"] = generate_mac(**content, secret=req.secret)
        raw_res = await req.client.api.request(Method.POST, path, content=content)
    except MatrixRequestError as e:
        return web.json_response(
            {
                "errcode": e.errcode,
                "error": e.message,
                "http_status": e.http_status,
            },
            status=HTTPStatus.INTERNAL_SERVER_ERROR,
        )
    login_res = LoginResponse.deserialize(raw_res)
    if req.update_client:
        return await _create_client(
            login_res.user_id,
            {
                "homeserver": str(req.client.api.base_url),
                "access_token": login_res.access_token,
                "device_id": login_res.device_id,
            },
        )
    return web.json_response(login_res.serialize())
@routes.post("/client/auth/{server}/login")
async def login(request: web.Request) -> web.Response:
    """Log in to a configured homeserver.

    With ``?sso=...`` this starts an SSO flow via ``_do_sso``; otherwise it
    performs a password login immediately via ``_do_login``.
    """
    req, err = await read_client_auth_request(request)
    if err is not None:
        return err
    handler = _do_sso if req.sso else _do_login
    return await handler(req)
async def _do_sso(req: AuthRequestInfo) -> web.Response:
    """Start an SSO login and register a waiter for its completion.

    Returns JSON with the homeserver's ``sso_url`` to open in a browser and the
    waiter ``id``. The homeserver redirects back to
    ``<public_url>/<base_path>/client/auth_external_sso/complete/<id>``, where
    the ``loginToken`` query parameter resolves the stored future.
    """
    flows = await req.client.get_login_flows()
    if not flows.supports_type(LoginType.SSO):
        return resp.sso_not_supported
    # The waiter ID is the only identifier of this pending login in the
    # browser-facing completion URL (presumably reachable without maubot auth;
    # verify), so draw it from a CSPRNG rather than the predictable `random`.
    alphabet = string.ascii_lowercase + string.digits
    waiter_id = "".join(secrets.choice(alphabet) for _ in range(16))
    cfg = get_config()
    public_url = (
        URL(cfg["server.public_url"])
        / cfg["server.base_path"].lstrip("/")
        / "client/auth_external_sso/complete"
        / waiter_id
    )
    sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
        {"redirectUrl": str(public_url)}
    )
    sso_waiters[waiter_id] = req, get_loop().create_future()
    return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
    """Log in with a password, or with an SSO login token when ``req.sso`` is set.

    A random ``maubot_XXXXXXXX`` device ID is used and the access token is not
    stored on ``req.client``. Homeserver errors are returned as JSON with the
    upstream HTTP status. With ``update_client`` the session is used to create
    or update a maubot client; otherwise the login response is returned as JSON.
    """
    suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
    device_id = f"maubot_{suffix}"
    shared_kwargs = {
        "device_id": device_id,
        "initial_device_display_name": req.device_name,
        "store_access_token": False,
    }
    if req.sso:
        credentials = {"token": login_token, "login_type": LoginType.TOKEN}
    else:
        credentials = {
            "identifier": req.username,
            "login_type": LoginType.PASSWORD,
            "password": req.password,
        }
    try:
        res = await req.client.login(**credentials, **shared_kwargs)
    except MatrixRequestError as e:
        error_body = {"errcode": e.errcode, "error": e.message}
        return web.json_response(error_body, status=e.http_status)
    if not req.update_client:
        return web.json_response(res.serialize())
    session = {
        "homeserver": str(req.client.api.base_url),
        "access_token": res.access_token,
        "device_id": res.device_id,
    }
    return await _create_or_update_client(res.user_id, session, is_login=True)
# Pending SSO logins keyed by waiter ID: the parsed request plus a future that
# complete_sso resolves with the homeserver's login token.
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}


@routes.post("/client/auth/{server}/sso/{id}/wait")
async def wait_sso(request: web.Request) -> web.Response:
    """Wait for an SSO login to complete, then finish the login.

    Blocks until ``complete_sso`` delivers the login token for waiter ``{id}``,
    then logs in with it. Unknown waiter IDs get a 404 JSON error instead of
    an unhandled KeyError.
    """
    waiter_id = request.match_info["id"]
    try:
        req, fut = sso_waiters[waiter_id]
    except KeyError:
        return web.json_response(
            {"error": "Invalid SSO session ID", "errcode": "sso_session_not_found"},
            status=HTTPStatus.NOT_FOUND,
        )
    try:
        login_token = await fut
    finally:
        # Drop the waiter whether the token arrived or this request was
        # cancelled, so the entry does not linger.
        sso_waiters.pop(waiter_id, None)
    return await _do_login(req, login_token)
@routes.get("/client/auth_external_sso/complete/{id}")
async def complete_sso(request: web.Request) -> web.Response:
    """Receive the homeserver's SSO redirect and hand the login token over.

    Resolves the future stored for waiter ``{id}`` with the ``loginToken``
    query parameter. Responses are plain text because this page is shown in
    the user's browser tab.
    """
    waiter = sso_waiters.get(request.match_info.get("id"))
    if waiter is None:
        return web.Response(status=404, text="Invalid session ID\n")
    fut = waiter[1]
    if fut.cancelled():
        return web.Response(status=200, text="The login was cancelled from the Maubot client\n")
    if fut.done():
        return web.Response(status=200, text="The login token was already received\n")
    login_token = request.query.get("loginToken")
    if login_token is None:
        return web.Response(status=400, text="Missing loginToken query parameter\n")
    try:
        fut.set_result(login_token)
    except asyncio.InvalidStateError:
        return web.Response(status=500, text="Invalid state\n")
    return web.Response(
        status=200,
        text="Login token received, please return to your Maubot client. "
        "This tab can be closed.\n",
    )