Add option to create/update client with mbc auth
This commit is contained in:
parent
8c3e3a3255
commit
85e5ea401c
8 changed files with 234 additions and 110 deletions
|
@ -15,15 +15,41 @@
|
|||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Any, Callable, Union, Optional
|
||||
import functools
|
||||
import inspect
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
|
||||
from prompt_toolkit.validation import Validator
|
||||
from questionary import prompt
|
||||
import click
|
||||
|
||||
from ..base import app
|
||||
from ..config import get_token
|
||||
from .validators import Required, ClickValidator
|
||||
|
||||
|
||||
def with_http(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
return await func(*args, sess=sess, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def with_authenticated_http(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, server: str, **kwargs):
|
||||
server, token = get_token(server)
|
||||
if not token:
|
||||
return
|
||||
async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess:
|
||||
return await func(*args, sess=sess, server=server, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def command(help: str) -> Callable[[Callable], Callable]:
|
||||
def decorator(func) -> Callable:
|
||||
questions = func.__inquirer_questions__.copy()
|
||||
|
@ -52,7 +78,10 @@ def command(help: str) -> Callable[[Callable], Callable]:
|
|||
if not resp and question_list:
|
||||
return
|
||||
kwargs = {**kwargs, **resp}
|
||||
func(*args, **kwargs)
|
||||
|
||||
res = func(*args, **kwargs)
|
||||
if inspect.isawaitable(res):
|
||||
asyncio.run(res)
|
||||
|
||||
return app.command(help=help)(wrapper)
|
||||
|
||||
|
|
|
@ -13,13 +13,11 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from urllib.parse import quote
|
||||
from urllib.request import urlopen, Request
|
||||
from urllib.error import HTTPError
|
||||
import functools
|
||||
import json
|
||||
|
||||
from colorama import Fore
|
||||
from yarl import URL
|
||||
import aiohttp
|
||||
import click
|
||||
|
||||
from ..config import get_token
|
||||
|
@ -27,8 +25,6 @@ from ..cliq import cliq
|
|||
|
||||
history_count: int = 10
|
||||
|
||||
enc = functools.partial(quote, safe="")
|
||||
|
||||
friendly_errors = {
|
||||
"server_not_found": "Registration target server not found.\n\n"
|
||||
"To log in or register through maubot, you must add the server to the\n"
|
||||
|
@ -37,6 +33,15 @@ friendly_errors = {
|
|||
}
|
||||
|
||||
|
||||
async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
|
||||
url = URL(server) / "_matrix/maubot/v1/client/auth/servers"
|
||||
async with sess.get(url) as resp:
|
||||
data = await resp.json()
|
||||
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
|
||||
for server in data.keys():
|
||||
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
|
||||
|
||||
|
||||
@cliq.command(help="Log into a Matrix account via the Maubot server")
|
||||
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
|
||||
@cliq.option("-u", "--username", help="The username to log in with", required_unless="list")
|
||||
|
@ -46,42 +51,40 @@ friendly_errors = {
|
|||
required=False, prompt=False)
|
||||
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
|
||||
default=False)
|
||||
@click.option("-c", "--update-client", help="Instead of returning the access token, "
|
||||
"create or update a client in maubot using it",
|
||||
is_flag=True, default=False)
|
||||
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
|
||||
def auth(homeserver: str, username: str, password: str, server: str, register: bool, list: bool
|
||||
) -> None:
|
||||
server, token = get_token(server)
|
||||
if not token:
|
||||
return
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
@cliq.with_authenticated_http
|
||||
async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
|
||||
list: bool, update_client: bool, sess: aiohttp.ClientSession) -> None:
|
||||
if list:
|
||||
url = f"{server}/_matrix/maubot/v1/client/auth/servers"
|
||||
with urlopen(Request(url, headers=headers)) as resp_data:
|
||||
resp = json.load(resp_data)
|
||||
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
|
||||
for server in resp.keys():
|
||||
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
|
||||
return
|
||||
await list_servers(server, sess)
|
||||
return
|
||||
endpoint = "register" if register else "login"
|
||||
headers["Content-Type"] = "application/json"
|
||||
url = f"{server}/_matrix/maubot/v1/client/auth/{enc(homeserver)}/{endpoint}"
|
||||
req = Request(url, headers=headers,
|
||||
data=json.dumps({
|
||||
"username": username,
|
||||
"password": password,
|
||||
}).encode("utf-8"))
|
||||
try:
|
||||
with urlopen(req) as resp_data:
|
||||
resp = json.load(resp_data)
|
||||
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
|
||||
if update_client:
|
||||
url = url.with_query({"update_client": "true"})
|
||||
req_data = {"username": username, "password": password}
|
||||
|
||||
async with sess.post(url, json=req_data) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
action = "registered" if register else "logged in as"
|
||||
print(f"{Fore.GREEN}Successfully {action} "
|
||||
f"{Fore.CYAN}{resp['user_id']}{Fore.GREEN}.")
|
||||
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{resp['access_token']}{Fore.RESET}")
|
||||
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{resp['device_id']}{Fore.RESET}")
|
||||
except HTTPError as e:
|
||||
try:
|
||||
err_data = json.load(e)
|
||||
error = friendly_errors.get(err_data["errcode"], err_data["error"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
error = str(e)
|
||||
action = "register" if register else "log in"
|
||||
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")
|
||||
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
|
||||
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
|
||||
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
|
||||
elif resp.status in (201, 202):
|
||||
data = await resp.json()
|
||||
action = "created" if resp.status == 201 else "updated"
|
||||
print(f"{Fore.GREEN}Successfully {action} client for "
|
||||
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
|
||||
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
|
||||
else:
|
||||
try:
|
||||
err_data = await resp.json()
|
||||
error = friendly_errors.get(err_data["errcode"], err_data["error"])
|
||||
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
|
||||
error = await resp.text()
|
||||
action = "register" if register else "log in"
|
||||
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -86,10 +86,8 @@ class Client:
|
|||
log=self.log, loop=self.loop, device_id=self.device_id,
|
||||
sync_store=SyncStoreProxy(self.db_instance),
|
||||
state_store=self.global_state_store)
|
||||
if OlmMachine and self.device_id and self.crypto_db:
|
||||
self.crypto_store = self._make_crypto_store()
|
||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||
self.client.crypto = self.crypto
|
||||
if self.enable_crypto:
|
||||
self._prepare_crypto()
|
||||
else:
|
||||
self.crypto_store = None
|
||||
self.crypto = None
|
||||
|
@ -102,10 +100,15 @@ class Client:
|
|||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||
|
||||
def _make_crypto_store(self) -> 'CryptoStore':
|
||||
if self.crypto_db:
|
||||
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
|
||||
raise ValueError("Crypto database not configured")
|
||||
@property
|
||||
def enable_crypto(self) -> bool:
|
||||
return bool(OlmMachine and self.device_id and self.crypto_db)
|
||||
|
||||
def _prepare_crypto(self) -> None:
|
||||
self.crypto_store = PgCryptoStore(account_id=self.id, pickle_key="mau.crypto",
|
||||
db=self.crypto_db)
|
||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||
self.client.crypto = self.crypto
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: Dict[str, Any]) -> None:
|
||||
|
@ -121,6 +124,19 @@ class Client:
|
|||
except Exception:
|
||||
self.log.exception("Failed to start")
|
||||
|
||||
async def _start_crypto(self) -> None:
|
||||
self.log.debug("Enabling end-to-end encryption support")
|
||||
await self.crypto_store.open()
|
||||
crypto_device_id = await self.crypto_store.get_device_id()
|
||||
if crypto_device_id and crypto_device_id != self.device_id:
|
||||
self.log.warning("Mismatching device ID in crypto store and main database, "
|
||||
"resetting encryption")
|
||||
await self.crypto_store.delete()
|
||||
crypto_device_id = None
|
||||
await self.crypto.load()
|
||||
if not crypto_device_id:
|
||||
await self.crypto_store.put_device_id(self.device_id)
|
||||
|
||||
async def _start(self, try_n: Optional[int] = 0) -> None:
|
||||
if not self.enabled:
|
||||
self.log.debug("Not starting disabled client")
|
||||
|
@ -129,7 +145,7 @@ class Client:
|
|||
self.log.warning("Ignoring start() call to started client")
|
||||
return
|
||||
try:
|
||||
user_id = await self.client.whoami()
|
||||
whoami = await self.client.whoami()
|
||||
except MatrixInvalidToken as e:
|
||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||
self.db_instance.enabled = False
|
||||
|
@ -143,8 +159,13 @@ class Client:
|
|||
f"retrying in {(try_n + 1) * 10}s: {e}")
|
||||
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
||||
return
|
||||
if user_id != self.id:
|
||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
|
||||
if whoami.user_id != self.id:
|
||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
|
||||
self.db_instance.enabled = False
|
||||
return
|
||||
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
|
||||
self.log.error(f"Device ID mismatch: expected {self.device_id}, "
|
||||
f"but got {whoami.device_id}")
|
||||
self.db_instance.enabled = False
|
||||
return
|
||||
if not self.filter_id:
|
||||
|
@ -167,15 +188,7 @@ class Client:
|
|||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
if self.crypto:
|
||||
self.log.debug("Enabling end-to-end encryption support")
|
||||
await self.crypto_store.open()
|
||||
crypto_device_id = await self.crypto_store.get_device_id()
|
||||
if crypto_device_id and crypto_device_id != self.device_id:
|
||||
self.log.warning("Mismatching device ID in crypto store and main database. "
|
||||
"Encryption may not work.")
|
||||
await self.crypto.load()
|
||||
if not crypto_device_id:
|
||||
await self.crypto_store.put_device_id(self.device_id)
|
||||
await self._start_crypto()
|
||||
self.start_sync()
|
||||
await self._update_remote_profile()
|
||||
self.started = True
|
||||
|
@ -285,23 +298,31 @@ class Client:
|
|||
else:
|
||||
await self._update_remote_profile()
|
||||
|
||||
async def update_access_details(self, access_token: str, homeserver: str) -> None:
|
||||
async def update_access_details(self, access_token: str, homeserver: str,
|
||||
device_id: Optional[str] = None) -> None:
|
||||
if not access_token and not homeserver:
|
||||
return
|
||||
elif access_token == self.access_token and homeserver == self.homeserver:
|
||||
return
|
||||
device_id = device_id or self.device_id
|
||||
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token, loop=self.loop,
|
||||
client_session=self.http_client, device_id=self.device_id,
|
||||
device_id=device_id, client_session=self.http_client,
|
||||
log=self.log, state_store=self.global_state_store)
|
||||
mxid = await new_client.whoami()
|
||||
if mxid != self.id:
|
||||
raise ValueError(f"MXID mismatch: {mxid}")
|
||||
whoami = await new_client.whoami()
|
||||
if whoami.user_id != self.id:
|
||||
raise ValueError(f"MXID mismatch: {whoami.user_id}")
|
||||
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
|
||||
new_client.sync_store = SyncStoreProxy(self.db_instance)
|
||||
self.stop_sync()
|
||||
self.client = new_client
|
||||
self.db_instance.homeserver = homeserver
|
||||
self.db_instance.access_token = access_token
|
||||
self.db_instance.device_id = device_id
|
||||
if self.enable_crypto:
|
||||
self._prepare_crypto()
|
||||
await self._start_crypto()
|
||||
self.start_sync()
|
||||
|
||||
async def _update_remote_profile(self) -> None:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -45,10 +45,11 @@ async def get_client(request: web.Request) -> web.Response:
|
|||
async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
|
||||
homeserver = data.get("homeserver", None)
|
||||
access_token = data.get("access_token", None)
|
||||
device_id = data.get("device_id", None)
|
||||
new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token,
|
||||
loop=Client.loop, client_session=Client.http_client)
|
||||
try:
|
||||
mxid = await new_client.whoami()
|
||||
whoami = await new_client.whoami()
|
||||
except MatrixInvalidToken:
|
||||
return resp.bad_client_access_token
|
||||
except MatrixRequestError:
|
||||
|
@ -56,27 +57,31 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
|
|||
except MatrixConnectionError:
|
||||
return resp.bad_client_connection_details
|
||||
if user_id is None:
|
||||
existing_client = Client.get(mxid, None)
|
||||
existing_client = Client.get(whoami.user_id, None)
|
||||
if existing_client is not None:
|
||||
return resp.user_exists
|
||||
elif mxid != user_id:
|
||||
return resp.mxid_mismatch(mxid)
|
||||
db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token,
|
||||
elif whoami.user_id != user_id:
|
||||
return resp.mxid_mismatch(whoami.user_id)
|
||||
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||
return resp.device_id_mismatch(whoami.device_id)
|
||||
db_instance = DBClient(id=whoami.user_id, homeserver=homeserver, access_token=access_token,
|
||||
enabled=data.get("enabled", True), next_batch=SyncToken(""),
|
||||
filter_id=FilterID(""), sync=data.get("sync", True),
|
||||
autojoin=data.get("autojoin", True), online=data.get("online", True),
|
||||
displayname=data.get("displayname", ""),
|
||||
avatar_url=data.get("avatar_url", ""))
|
||||
avatar_url=data.get("avatar_url", ""),
|
||||
device_id=device_id)
|
||||
client = Client(db_instance)
|
||||
client.db_instance.insert()
|
||||
await client.start()
|
||||
return resp.created(client.to_dict())
|
||||
|
||||
|
||||
async def _update_client(client: Client, data: dict) -> web.Response:
|
||||
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
|
||||
try:
|
||||
await client.update_access_details(data.get("access_token", None),
|
||||
data.get("homeserver", None))
|
||||
data.get("homeserver", None),
|
||||
data.get("device_id", None))
|
||||
except MatrixInvalidToken:
|
||||
return resp.bad_client_access_token
|
||||
except MatrixRequestError:
|
||||
|
@ -93,7 +98,16 @@ async def _update_client(client: Client, data: dict) -> web.Response:
|
|||
client.autojoin = data.get("autojoin", client.autojoin)
|
||||
client.online = data.get("online", client.online)
|
||||
client.sync = data.get("sync", client.sync)
|
||||
return resp.updated(client.to_dict())
|
||||
return resp.updated(client.to_dict(), is_login=is_login)
|
||||
|
||||
|
||||
async def _create_or_update_client(user_id: UserID, data: dict, is_login: bool = False
|
||||
) -> web.Response:
|
||||
client = Client.get(user_id, None)
|
||||
if not client:
|
||||
return await _create_client(user_id, data)
|
||||
else:
|
||||
return await _update_client(client, data, is_login=is_login)
|
||||
|
||||
|
||||
@routes.post("/client/new")
|
||||
|
@ -108,15 +122,11 @@ async def create_client(request: web.Request) -> web.Response:
|
|||
@routes.put("/client/{id}")
|
||||
async def update_client(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
try:
|
||||
data = await request.json()
|
||||
except JSONDecodeError:
|
||||
return resp.body_not_json
|
||||
if not client:
|
||||
return await _create_client(user_id, data)
|
||||
else:
|
||||
return await _update_client(client, data)
|
||||
return await _create_or_update_client(user_id, data)
|
||||
|
||||
|
||||
@routes.delete("/client/{id}")
|
||||
|
|
|
@ -25,10 +25,11 @@ from aiohttp import web
|
|||
from mautrix.api import SynapseAdminPath, Method
|
||||
from mautrix.errors import MatrixRequestError
|
||||
from mautrix.client import ClientAPI
|
||||
from mautrix.types import LoginType
|
||||
from mautrix.types import LoginType, LoginResponse
|
||||
|
||||
from .base import routes, get_config, get_loop
|
||||
from .responses import resp
|
||||
from .client import _create_or_update_client, _create_client
|
||||
|
||||
|
||||
def known_homeservers() -> Dict[str, Dict[str, str]]:
|
||||
|
@ -46,6 +47,7 @@ class AuthRequestInfo(NamedTuple):
|
|||
username: str
|
||||
password: str
|
||||
user_type: str
|
||||
update_client: bool
|
||||
|
||||
|
||||
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
|
||||
|
@ -70,15 +72,16 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
|
|||
secret = server.get("secret")
|
||||
api = ClientAPI(base_url=base_url, loop=get_loop())
|
||||
user_type = body.get("user_type", "bot")
|
||||
return AuthRequestInfo(api, secret, username, password, user_type), None
|
||||
update_client = request.query.get("update_client", "").lower() in ("1", "true", "yes")
|
||||
return AuthRequestInfo(api, secret, username, password, user_type, update_client), None
|
||||
|
||||
|
||||
def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool = False,
|
||||
def generate_mac(secret: str, nonce: str, username: str, password: str, admin: bool = False,
|
||||
user_type: str = None) -> str:
|
||||
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
|
||||
mac.update(nonce.encode("utf-8"))
|
||||
mac.update(b"\x00")
|
||||
mac.update(user.encode("utf-8"))
|
||||
mac.update(username.encode("utf-8"))
|
||||
mac.update(b"\x00")
|
||||
mac.update(password.encode("utf-8"))
|
||||
mac.update(b"\x00")
|
||||
|
@ -94,28 +97,34 @@ async def register(request: web.Request) -> web.Response:
|
|||
info, err = await read_client_auth_request(request)
|
||||
if err is not None:
|
||||
return err
|
||||
client: ClientAPI
|
||||
client, secret, username, password, user_type = info
|
||||
if not secret:
|
||||
if not info.secret:
|
||||
return resp.registration_secret_not_found
|
||||
path = SynapseAdminPath.v1.register
|
||||
res = await client.api.request(Method.GET, path)
|
||||
res = await info.client.api.request(Method.GET, path)
|
||||
content = {
|
||||
"nonce": res["nonce"],
|
||||
"username": username,
|
||||
"password": password,
|
||||
"username": info.username,
|
||||
"password": info.password,
|
||||
"admin": False,
|
||||
"mac": generate_mac(secret, res["nonce"], username, password, user_type=user_type),
|
||||
"user_type": user_type,
|
||||
"user_type": info.user_type,
|
||||
}
|
||||
content["mac"] = generate_mac(**content, secret=info.secret)
|
||||
try:
|
||||
return web.json_response(await client.api.request(Method.POST, path, content=content))
|
||||
raw_res = await info.client.api.request(Method.POST, path, content=content)
|
||||
except MatrixRequestError as e:
|
||||
return web.json_response({
|
||||
"errcode": e.errcode,
|
||||
"error": e.message,
|
||||
"http_status": e.http_status,
|
||||
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
login_res = LoginResponse.deserialize(raw_res)
|
||||
if info.update_client:
|
||||
return await _create_client(login_res.user_id, {
|
||||
"homeserver": str(info.client.api.base_url),
|
||||
"access_token": login_res.access_token,
|
||||
"device_id": login_res.device_id,
|
||||
})
|
||||
return web.json_response(login_res.serialize())
|
||||
|
||||
|
||||
@routes.post("/client/auth/{server}/login")
|
||||
|
@ -129,9 +138,15 @@ async def login(request: web.Request) -> web.Response:
|
|||
res = await client.login(identifier=info.username, login_type=LoginType.PASSWORD,
|
||||
password=info.password, device_id=f"maubot_{device_id}",
|
||||
initial_device_display_name="Maubot", store_access_token=False)
|
||||
return web.json_response(res.serialize())
|
||||
except MatrixRequestError as e:
|
||||
return web.json_response({
|
||||
"errcode": e.errcode,
|
||||
"error": e.message,
|
||||
}, status=e.http_status)
|
||||
if info.update_client:
|
||||
return await _create_or_update_client(res.user_id, {
|
||||
"homeserver": str(client.api.base_url),
|
||||
"access_token": res.access_token,
|
||||
"device_id": res.device_id,
|
||||
}, is_login=True)
|
||||
return web.json_response(res.serialize())
|
||||
|
|
|
@ -69,6 +69,13 @@ class _Response:
|
|||
"errcode": "mxid_mismatch",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
def device_id_mismatch(self, found: str) -> web.Response:
|
||||
return web.json_response({
|
||||
"error": "The Matrix device ID of the client and the device ID of the access token "
|
||||
f"don't match. Access token is for device {found}",
|
||||
"errcode": "mxid_mismatch",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
@property
|
||||
def pid_mismatch(self) -> web.Response:
|
||||
return web.json_response({
|
||||
|
@ -294,8 +301,9 @@ class _Response:
|
|||
def found(data: dict) -> web.Response:
|
||||
return web.json_response(data, status=HTTPStatus.OK)
|
||||
|
||||
def updated(self, data: dict) -> web.Response:
|
||||
return self.found(data)
|
||||
@staticmethod
|
||||
def updated(data: dict, is_login: bool = False) -> web.Response:
|
||||
return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK)
|
||||
|
||||
def logged_in(self, token: str) -> web.Response:
|
||||
return self.found({
|
||||
|
|
|
@ -366,7 +366,7 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/MatrixClient'
|
||||
responses:
|
||||
200:
|
||||
202:
|
||||
description: Client updated
|
||||
content:
|
||||
application/json:
|
||||
|
@ -454,6 +454,12 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: update_client
|
||||
in: query
|
||||
description: Should maubot store the access details in a Client instead of returning them?
|
||||
required: false
|
||||
schema:
|
||||
type: boolean
|
||||
post:
|
||||
operationId: client_auth_register
|
||||
summary: |
|
||||
|
@ -475,18 +481,29 @@ paths:
|
|||
properties:
|
||||
access_token:
|
||||
type: string
|
||||
example: token_here
|
||||
example: syt_123_456_789
|
||||
user_id:
|
||||
type: string
|
||||
example: '@putkiteippi:maunium.net'
|
||||
home_server:
|
||||
type: string
|
||||
example: maunium.net
|
||||
device_id:
|
||||
type: string
|
||||
example: device_id_here
|
||||
example: maubot_F00BAR12
|
||||
201:
|
||||
description: Client created (when update_client is true)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/MatrixClient'
|
||||
401:
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
409:
|
||||
description: |
|
||||
There is already a client with the user ID of that token.
|
||||
This should usually not happen, because the user ID was just created.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Error'
|
||||
500:
|
||||
$ref: '#/components/responses/MatrixServerError'
|
||||
'/client/auth/{server}/login':
|
||||
|
@ -497,6 +514,12 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: update_client
|
||||
in: query
|
||||
description: Should maubot store the access details in a Client instead of returning them?
|
||||
required: false
|
||||
schema:
|
||||
type: boolean
|
||||
post:
|
||||
operationId: client_auth_login
|
||||
summary: Log in to the given Matrix server via the maubot server
|
||||
|
@ -519,10 +542,22 @@ paths:
|
|||
example: '@putkiteippi:maunium.net'
|
||||
access_token:
|
||||
type: string
|
||||
example: token_here
|
||||
example: syt_123_456_789
|
||||
device_id:
|
||||
type: string
|
||||
example: device_id_here
|
||||
example: maubot_F00BAR12
|
||||
201:
|
||||
description: Client created (when update_client is true)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/MatrixClient'
|
||||
202:
|
||||
description: Client updated (when update_client is true)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/MatrixClient'
|
||||
401:
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
500:
|
||||
|
@ -641,6 +676,9 @@ components:
|
|||
access_token:
|
||||
type: string
|
||||
description: The Matrix access token for this client.
|
||||
device_id:
|
||||
type: string
|
||||
description: The Matrix device ID corresponding to the access token.
|
||||
enabled:
|
||||
type: boolean
|
||||
example: true
|
||||
|
|
|
@ -144,13 +144,13 @@ async def main():
|
|||
|
||||
while True:
|
||||
try:
|
||||
whoami_user_id = await client.whoami()
|
||||
whoami = await client.whoami()
|
||||
except Exception:
|
||||
log.exception("Failed to connect to homeserver, retrying in 10 seconds...")
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
if whoami_user_id != user_id:
|
||||
log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami_user_id}")
|
||||
if whoami.user_id != user_id:
|
||||
log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami.user_id}")
|
||||
sys.exit(1)
|
||||
break
|
||||
|
||||
|
|
Loading…
Reference in a new issue