Add option to create/update client with mbc auth

This commit is contained in:
Tulir Asokan 2021-11-19 17:10:51 +02:00
parent 8c3e3a3255
commit 85e5ea401c
8 changed files with 234 additions and 110 deletions

View file

@ -15,15 +15,41 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Callable, Union, Optional from typing import Any, Callable, Union, Optional
import functools import functools
import inspect
import asyncio
import aiohttp
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from questionary import prompt from questionary import prompt
import click import click
from ..base import app from ..base import app
from ..config import get_token
from .validators import Required, ClickValidator from .validators import Required, ClickValidator
def with_http(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with aiohttp.ClientSession() as sess:
return await func(*args, sess=sess, **kwargs)
return wrapper
def with_authenticated_http(func):
@functools.wraps(func)
async def wrapper(*args, server: str, **kwargs):
server, token = get_token(server)
if not token:
return
async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess:
return await func(*args, sess=sess, server=server, **kwargs)
return wrapper
def command(help: str) -> Callable[[Callable], Callable]: def command(help: str) -> Callable[[Callable], Callable]:
def decorator(func) -> Callable: def decorator(func) -> Callable:
questions = func.__inquirer_questions__.copy() questions = func.__inquirer_questions__.copy()
@ -52,7 +78,10 @@ def command(help: str) -> Callable[[Callable], Callable]:
if not resp and question_list: if not resp and question_list:
return return
kwargs = {**kwargs, **resp} kwargs = {**kwargs, **resp}
func(*args, **kwargs)
res = func(*args, **kwargs)
if inspect.isawaitable(res):
asyncio.run(res)
return app.command(help=help)(wrapper) return app.command(help=help)(wrapper)

View file

@ -13,13 +13,11 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from urllib.parse import quote
from urllib.request import urlopen, Request
from urllib.error import HTTPError
import functools
import json import json
from colorama import Fore from colorama import Fore
from yarl import URL
import aiohttp
import click import click
from ..config import get_token from ..config import get_token
@ -27,8 +25,6 @@ from ..cliq import cliq
history_count: int = 10 history_count: int = 10
enc = functools.partial(quote, safe="")
friendly_errors = { friendly_errors = {
"server_not_found": "Registration target server not found.\n\n" "server_not_found": "Registration target server not found.\n\n"
"To log in or register through maubot, you must add the server to the\n" "To log in or register through maubot, you must add the server to the\n"
@ -37,6 +33,15 @@ friendly_errors = {
} }
async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
url = URL(server) / "_matrix/maubot/v1/client/auth/servers"
async with sess.get(url) as resp:
data = await resp.json()
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
for server in data.keys():
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
@cliq.command(help="Log into a Matrix account via the Maubot server") @cliq.command(help="Log into a Matrix account via the Maubot server")
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list") @cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
@cliq.option("-u", "--username", help="The username to log in with", required_unless="list") @cliq.option("-u", "--username", help="The username to log in with", required_unless="list")
@ -46,42 +51,40 @@ friendly_errors = {
required=False, prompt=False) required=False, prompt=False)
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True, @click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
default=False) default=False)
@click.option("-c", "--update-client", help="Instead of returning the access token, "
"create or update a client in maubot using it",
is_flag=True, default=False)
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False) @click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
def auth(homeserver: str, username: str, password: str, server: str, register: bool, list: bool @cliq.with_authenticated_http
) -> None: async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
server, token = get_token(server) list: bool, update_client: bool, sess: aiohttp.ClientSession) -> None:
if not token:
return
headers = {"Authorization": f"Bearer {token}"}
if list: if list:
url = f"{server}/_matrix/maubot/v1/client/auth/servers" await list_servers(server, sess)
with urlopen(Request(url, headers=headers)) as resp_data:
resp = json.load(resp_data)
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
for server in resp.keys():
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
return return
endpoint = "register" if register else "login" endpoint = "register" if register else "login"
headers["Content-Type"] = "application/json" url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
url = f"{server}/_matrix/maubot/v1/client/auth/{enc(homeserver)}/{endpoint}" if update_client:
req = Request(url, headers=headers, url = url.with_query({"update_client": "true"})
data=json.dumps({ req_data = {"username": username, "password": password}
"username": username,
"password": password, async with sess.post(url, json=req_data) as resp:
}).encode("utf-8")) if resp.status == 200:
try: data = await resp.json()
with urlopen(req) as resp_data:
resp = json.load(resp_data)
action = "registered" if register else "logged in as" action = "registered" if register else "logged in as"
print(f"{Fore.GREEN}Successfully {action} " print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
f"{Fore.CYAN}{resp['user_id']}{Fore.GREEN}.") print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{resp['access_token']}{Fore.RESET}") print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{resp['device_id']}{Fore.RESET}") elif resp.status in (201, 202):
except HTTPError as e: data = await resp.json()
action = "created" if resp.status == 201 else "updated"
print(f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
else:
try: try:
err_data = json.load(e) err_data = await resp.json()
error = friendly_errors.get(err_data["errcode"], err_data["error"]) error = friendly_errors.get(err_data["errcode"], err_data["error"])
except (json.JSONDecodeError, KeyError): except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
error = str(e) error = await resp.text()
action = "register" if register else "log in" action = "register" if register else "log in"
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -86,10 +86,8 @@ class Client:
log=self.log, loop=self.loop, device_id=self.device_id, log=self.log, loop=self.loop, device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance), sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store) state_store=self.global_state_store)
if OlmMachine and self.device_id and self.crypto_db: if self.enable_crypto:
self.crypto_store = self._make_crypto_store() self._prepare_crypto()
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
else: else:
self.crypto_store = None self.crypto_store = None
self.crypto = None self.crypto = None
@ -102,10 +100,15 @@ class Client:
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False)) self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True)) self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
def _make_crypto_store(self) -> 'CryptoStore': @property
if self.crypto_db: def enable_crypto(self) -> bool:
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db) return bool(OlmMachine and self.device_id and self.crypto_db)
raise ValueError("Crypto database not configured")
def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(account_id=self.id, pickle_key="mau.crypto",
db=self.crypto_db)
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]: def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
async def handler(data: Dict[str, Any]) -> None: async def handler(data: Dict[str, Any]) -> None:
@ -121,6 +124,19 @@ class Client:
except Exception: except Exception:
self.log.exception("Failed to start") self.log.exception("Failed to start")
async def _start_crypto(self) -> None:
self.log.debug("Enabling end-to-end encryption support")
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning("Mismatching device ID in crypto store and main database, "
"resetting encryption")
await self.crypto_store.delete()
crypto_device_id = None
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
async def _start(self, try_n: Optional[int] = 0) -> None: async def _start(self, try_n: Optional[int] = 0) -> None:
if not self.enabled: if not self.enabled:
self.log.debug("Not starting disabled client") self.log.debug("Not starting disabled client")
@ -129,7 +145,7 @@ class Client:
self.log.warning("Ignoring start() call to started client") self.log.warning("Ignoring start() call to started client")
return return
try: try:
user_id = await self.client.whoami() whoami = await self.client.whoami()
except MatrixInvalidToken as e: except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client") self.log.error(f"Invalid token: {e}. Disabling client")
self.db_instance.enabled = False self.db_instance.enabled = False
@ -143,8 +159,13 @@ class Client:
f"retrying in {(try_n + 1) * 10}s: {e}") f"retrying in {(try_n + 1) * 10}s: {e}")
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop) _ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
return return
if user_id != self.id: if whoami.user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}") self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
self.db_instance.enabled = False
return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(f"Device ID mismatch: expected {self.device_id}, "
f"but got {whoami.device_id}")
self.db_instance.enabled = False self.db_instance.enabled = False
return return
if not self.filter_id: if not self.filter_id:
@ -167,15 +188,7 @@ class Client:
if self.avatar_url != "disable": if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url) await self.client.set_avatar_url(self.avatar_url)
if self.crypto: if self.crypto:
self.log.debug("Enabling end-to-end encryption support") await self._start_crypto()
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning("Mismatching device ID in crypto store and main database. "
"Encryption may not work.")
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
self.start_sync() self.start_sync()
await self._update_remote_profile() await self._update_remote_profile()
self.started = True self.started = True
@ -285,23 +298,31 @@ class Client:
else: else:
await self._update_remote_profile() await self._update_remote_profile()
async def update_access_details(self, access_token: str, homeserver: str) -> None: async def update_access_details(self, access_token: str, homeserver: str,
device_id: Optional[str] = None) -> None:
if not access_token and not homeserver: if not access_token and not homeserver:
return return
elif access_token == self.access_token and homeserver == self.homeserver: elif access_token == self.access_token and homeserver == self.homeserver:
return return
device_id = device_id or self.device_id
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver, new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop, token=access_token or self.access_token, loop=self.loop,
client_session=self.http_client, device_id=self.device_id, device_id=device_id, client_session=self.http_client,
log=self.log, state_store=self.global_state_store) log=self.log, state_store=self.global_state_store)
mxid = await new_client.whoami() whoami = await new_client.whoami()
if mxid != self.id: if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {mxid}") raise ValueError(f"MXID mismatch: {whoami.user_id}")
elif whoami.device_id and device_id and whoami.device_id != device_id:
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
new_client.sync_store = SyncStoreProxy(self.db_instance) new_client.sync_store = SyncStoreProxy(self.db_instance)
self.stop_sync() self.stop_sync()
self.client = new_client self.client = new_client
self.db_instance.homeserver = homeserver self.db_instance.homeserver = homeserver
self.db_instance.access_token = access_token self.db_instance.access_token = access_token
self.db_instance.device_id = device_id
if self.enable_crypto:
self._prepare_crypto()
await self._start_crypto()
self.start_sync() self.start_sync()
async def _update_remote_profile(self) -> None: async def _update_remote_profile(self) -> None:

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system. # maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan # Copyright (C) 2021 Tulir Asokan
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -45,10 +45,11 @@ async def get_client(request: web.Request) -> web.Response:
async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response: async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
homeserver = data.get("homeserver", None) homeserver = data.get("homeserver", None)
access_token = data.get("access_token", None) access_token = data.get("access_token", None)
device_id = data.get("device_id", None)
new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token, new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token,
loop=Client.loop, client_session=Client.http_client) loop=Client.loop, client_session=Client.http_client)
try: try:
mxid = await new_client.whoami() whoami = await new_client.whoami()
except MatrixInvalidToken: except MatrixInvalidToken:
return resp.bad_client_access_token return resp.bad_client_access_token
except MatrixRequestError: except MatrixRequestError:
@ -56,27 +57,31 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
except MatrixConnectionError: except MatrixConnectionError:
return resp.bad_client_connection_details return resp.bad_client_connection_details
if user_id is None: if user_id is None:
existing_client = Client.get(mxid, None) existing_client = Client.get(whoami.user_id, None)
if existing_client is not None: if existing_client is not None:
return resp.user_exists return resp.user_exists
elif mxid != user_id: elif whoami.user_id != user_id:
return resp.mxid_mismatch(mxid) return resp.mxid_mismatch(whoami.user_id)
db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token, elif whoami.device_id and device_id and whoami.device_id != device_id:
return resp.device_id_mismatch(whoami.device_id)
db_instance = DBClient(id=whoami.user_id, homeserver=homeserver, access_token=access_token,
enabled=data.get("enabled", True), next_batch=SyncToken(""), enabled=data.get("enabled", True), next_batch=SyncToken(""),
filter_id=FilterID(""), sync=data.get("sync", True), filter_id=FilterID(""), sync=data.get("sync", True),
autojoin=data.get("autojoin", True), online=data.get("online", True), autojoin=data.get("autojoin", True), online=data.get("online", True),
displayname=data.get("displayname", ""), displayname=data.get("displayname", ""),
avatar_url=data.get("avatar_url", "")) avatar_url=data.get("avatar_url", ""),
device_id=device_id)
client = Client(db_instance) client = Client(db_instance)
client.db_instance.insert() client.db_instance.insert()
await client.start() await client.start()
return resp.created(client.to_dict()) return resp.created(client.to_dict())
async def _update_client(client: Client, data: dict) -> web.Response: async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
try: try:
await client.update_access_details(data.get("access_token", None), await client.update_access_details(data.get("access_token", None),
data.get("homeserver", None)) data.get("homeserver", None),
data.get("device_id", None))
except MatrixInvalidToken: except MatrixInvalidToken:
return resp.bad_client_access_token return resp.bad_client_access_token
except MatrixRequestError: except MatrixRequestError:
@ -93,7 +98,16 @@ async def _update_client(client: Client, data: dict) -> web.Response:
client.autojoin = data.get("autojoin", client.autojoin) client.autojoin = data.get("autojoin", client.autojoin)
client.online = data.get("online", client.online) client.online = data.get("online", client.online)
client.sync = data.get("sync", client.sync) client.sync = data.get("sync", client.sync)
return resp.updated(client.to_dict()) return resp.updated(client.to_dict(), is_login=is_login)
async def _create_or_update_client(user_id: UserID, data: dict, is_login: bool = False
) -> web.Response:
client = Client.get(user_id, None)
if not client:
return await _create_client(user_id, data)
else:
return await _update_client(client, data, is_login=is_login)
@routes.post("/client/new") @routes.post("/client/new")
@ -108,15 +122,11 @@ async def create_client(request: web.Request) -> web.Response:
@routes.put("/client/{id}") @routes.put("/client/{id}")
async def update_client(request: web.Request) -> web.Response: async def update_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None) user_id = request.match_info.get("id", None)
client = Client.get(user_id, None)
try: try:
data = await request.json() data = await request.json()
except JSONDecodeError: except JSONDecodeError:
return resp.body_not_json return resp.body_not_json
if not client: return await _create_or_update_client(user_id, data)
return await _create_client(user_id, data)
else:
return await _update_client(client, data)
@routes.delete("/client/{id}") @routes.delete("/client/{id}")

View file

@ -25,10 +25,11 @@ from aiohttp import web
from mautrix.api import SynapseAdminPath, Method from mautrix.api import SynapseAdminPath, Method
from mautrix.errors import MatrixRequestError from mautrix.errors import MatrixRequestError
from mautrix.client import ClientAPI from mautrix.client import ClientAPI
from mautrix.types import LoginType from mautrix.types import LoginType, LoginResponse
from .base import routes, get_config, get_loop from .base import routes, get_config, get_loop
from .responses import resp from .responses import resp
from .client import _create_or_update_client, _create_client
def known_homeservers() -> Dict[str, Dict[str, str]]: def known_homeservers() -> Dict[str, Dict[str, str]]:
@ -46,6 +47,7 @@ class AuthRequestInfo(NamedTuple):
username: str username: str
password: str password: str
user_type: str user_type: str
update_client: bool
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo], async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
@ -70,15 +72,16 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
secret = server.get("secret") secret = server.get("secret")
api = ClientAPI(base_url=base_url, loop=get_loop()) api = ClientAPI(base_url=base_url, loop=get_loop())
user_type = body.get("user_type", "bot") user_type = body.get("user_type", "bot")
return AuthRequestInfo(api, secret, username, password, user_type), None update_client = request.query.get("update_client", "").lower() in ("1", "true", "yes")
return AuthRequestInfo(api, secret, username, password, user_type, update_client), None
def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool = False, def generate_mac(secret: str, nonce: str, username: str, password: str, admin: bool = False,
user_type: str = None) -> str: user_type: str = None) -> str:
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1) mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
mac.update(nonce.encode("utf-8")) mac.update(nonce.encode("utf-8"))
mac.update(b"\x00") mac.update(b"\x00")
mac.update(user.encode("utf-8")) mac.update(username.encode("utf-8"))
mac.update(b"\x00") mac.update(b"\x00")
mac.update(password.encode("utf-8")) mac.update(password.encode("utf-8"))
mac.update(b"\x00") mac.update(b"\x00")
@ -94,28 +97,34 @@ async def register(request: web.Request) -> web.Response:
info, err = await read_client_auth_request(request) info, err = await read_client_auth_request(request)
if err is not None: if err is not None:
return err return err
client: ClientAPI if not info.secret:
client, secret, username, password, user_type = info
if not secret:
return resp.registration_secret_not_found return resp.registration_secret_not_found
path = SynapseAdminPath.v1.register path = SynapseAdminPath.v1.register
res = await client.api.request(Method.GET, path) res = await info.client.api.request(Method.GET, path)
content = { content = {
"nonce": res["nonce"], "nonce": res["nonce"],
"username": username, "username": info.username,
"password": password, "password": info.password,
"admin": False, "admin": False,
"mac": generate_mac(secret, res["nonce"], username, password, user_type=user_type), "user_type": info.user_type,
"user_type": user_type,
} }
content["mac"] = generate_mac(**content, secret=info.secret)
try: try:
return web.json_response(await client.api.request(Method.POST, path, content=content)) raw_res = await info.client.api.request(Method.POST, path, content=content)
except MatrixRequestError as e: except MatrixRequestError as e:
return web.json_response({ return web.json_response({
"errcode": e.errcode, "errcode": e.errcode,
"error": e.message, "error": e.message,
"http_status": e.http_status, "http_status": e.http_status,
}, status=HTTPStatus.INTERNAL_SERVER_ERROR) }, status=HTTPStatus.INTERNAL_SERVER_ERROR)
login_res = LoginResponse.deserialize(raw_res)
if info.update_client:
return await _create_client(login_res.user_id, {
"homeserver": str(info.client.api.base_url),
"access_token": login_res.access_token,
"device_id": login_res.device_id,
})
return web.json_response(login_res.serialize())
@routes.post("/client/auth/{server}/login") @routes.post("/client/auth/{server}/login")
@ -129,9 +138,15 @@ async def login(request: web.Request) -> web.Response:
res = await client.login(identifier=info.username, login_type=LoginType.PASSWORD, res = await client.login(identifier=info.username, login_type=LoginType.PASSWORD,
password=info.password, device_id=f"maubot_{device_id}", password=info.password, device_id=f"maubot_{device_id}",
initial_device_display_name="Maubot", store_access_token=False) initial_device_display_name="Maubot", store_access_token=False)
return web.json_response(res.serialize())
except MatrixRequestError as e: except MatrixRequestError as e:
return web.json_response({ return web.json_response({
"errcode": e.errcode, "errcode": e.errcode,
"error": e.message, "error": e.message,
}, status=e.http_status) }, status=e.http_status)
if info.update_client:
return await _create_or_update_client(res.user_id, {
"homeserver": str(client.api.base_url),
"access_token": res.access_token,
"device_id": res.device_id,
}, is_login=True)
return web.json_response(res.serialize())

View file

@ -69,6 +69,13 @@ class _Response:
"errcode": "mxid_mismatch", "errcode": "mxid_mismatch",
}, status=HTTPStatus.BAD_REQUEST) }, status=HTTPStatus.BAD_REQUEST)
def device_id_mismatch(self, found: str) -> web.Response:
return web.json_response({
"error": "The Matrix device ID of the client and the device ID of the access token "
f"don't match. Access token is for device {found}",
"errcode": "mxid_mismatch",
}, status=HTTPStatus.BAD_REQUEST)
@property @property
def pid_mismatch(self) -> web.Response: def pid_mismatch(self) -> web.Response:
return web.json_response({ return web.json_response({
@ -294,8 +301,9 @@ class _Response:
def found(data: dict) -> web.Response: def found(data: dict) -> web.Response:
return web.json_response(data, status=HTTPStatus.OK) return web.json_response(data, status=HTTPStatus.OK)
def updated(self, data: dict) -> web.Response: @staticmethod
return self.found(data) def updated(data: dict, is_login: bool = False) -> web.Response:
return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK)
def logged_in(self, token: str) -> web.Response: def logged_in(self, token: str) -> web.Response:
return self.found({ return self.found({

View file

@ -366,7 +366,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/MatrixClient' $ref: '#/components/schemas/MatrixClient'
responses: responses:
200: 202:
description: Client updated description: Client updated
content: content:
application/json: application/json:
@ -454,6 +454,12 @@ paths:
required: true required: true
schema: schema:
type: string type: string
- name: update_client
in: query
description: Should maubot store the access details in a Client instead of returning them?
required: false
schema:
type: boolean
post: post:
operationId: client_auth_register operationId: client_auth_register
summary: | summary: |
@ -475,18 +481,29 @@ paths:
properties: properties:
access_token: access_token:
type: string type: string
example: token_here example: syt_123_456_789
user_id: user_id:
type: string type: string
example: '@putkiteippi:maunium.net' example: '@putkiteippi:maunium.net'
home_server:
type: string
example: maunium.net
device_id: device_id:
type: string type: string
example: device_id_here example: maubot_F00BAR12
201:
description: Client created (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
401: 401:
$ref: '#/components/responses/Unauthorized' $ref: '#/components/responses/Unauthorized'
409:
description: |
There is already a client with the user ID of that token.
This should usually not happen, because the user ID was just created.
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
500: 500:
$ref: '#/components/responses/MatrixServerError' $ref: '#/components/responses/MatrixServerError'
'/client/auth/{server}/login': '/client/auth/{server}/login':
@ -497,6 +514,12 @@ paths:
required: true required: true
schema: schema:
type: string type: string
- name: update_client
in: query
description: Should maubot store the access details in a Client instead of returning them?
required: false
schema:
type: boolean
post: post:
operationId: client_auth_login operationId: client_auth_login
summary: Log in to the given Matrix server via the maubot server summary: Log in to the given Matrix server via the maubot server
@ -519,10 +542,22 @@ paths:
example: '@putkiteippi:maunium.net' example: '@putkiteippi:maunium.net'
access_token: access_token:
type: string type: string
example: token_here example: syt_123_456_789
device_id: device_id:
type: string type: string
example: device_id_here example: maubot_F00BAR12
201:
description: Client created (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
202:
description: Client updated (when update_client is true)
content:
application/json:
schema:
$ref: '#/components/schemas/MatrixClient'
401: 401:
$ref: '#/components/responses/Unauthorized' $ref: '#/components/responses/Unauthorized'
500: 500:
@ -641,6 +676,9 @@ components:
access_token: access_token:
type: string type: string
description: The Matrix access token for this client. description: The Matrix access token for this client.
device_id:
type: string
description: The Matrix device ID corresponding to the access token.
enabled: enabled:
type: boolean type: boolean
example: true example: true

View file

@ -144,13 +144,13 @@ async def main():
while True: while True:
try: try:
whoami_user_id = await client.whoami() whoami = await client.whoami()
except Exception: except Exception:
log.exception("Failed to connect to homeserver, retrying in 10 seconds...") log.exception("Failed to connect to homeserver, retrying in 10 seconds...")
await asyncio.sleep(10) await asyncio.sleep(10)
continue continue
if whoami_user_id != user_id: if whoami.user_id != user_id:
log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami_user_id}") log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami.user_id}")
sys.exit(1) sys.exit(1)
break break