Add option to create/update client with mbc auth
This commit is contained in:
parent
8c3e3a3255
commit
85e5ea401c
8 changed files with 234 additions and 110 deletions
|
@ -15,15 +15,41 @@
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Any, Callable, Union, Optional
|
from typing import Any, Callable, Union, Optional
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
from prompt_toolkit.validation import Validator
|
from prompt_toolkit.validation import Validator
|
||||||
from questionary import prompt
|
from questionary import prompt
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from ..base import app
|
from ..base import app
|
||||||
|
from ..config import get_token
|
||||||
from .validators import Required, ClickValidator
|
from .validators import Required, ClickValidator
|
||||||
|
|
||||||
|
|
||||||
|
def with_http(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
return await func(*args, sess=sess, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def with_authenticated_http(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args, server: str, **kwargs):
|
||||||
|
server, token = get_token(server)
|
||||||
|
if not token:
|
||||||
|
return
|
||||||
|
async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess:
|
||||||
|
return await func(*args, sess=sess, server=server, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def command(help: str) -> Callable[[Callable], Callable]:
|
def command(help: str) -> Callable[[Callable], Callable]:
|
||||||
def decorator(func) -> Callable:
|
def decorator(func) -> Callable:
|
||||||
questions = func.__inquirer_questions__.copy()
|
questions = func.__inquirer_questions__.copy()
|
||||||
|
@ -52,7 +78,10 @@ def command(help: str) -> Callable[[Callable], Callable]:
|
||||||
if not resp and question_list:
|
if not resp and question_list:
|
||||||
return
|
return
|
||||||
kwargs = {**kwargs, **resp}
|
kwargs = {**kwargs, **resp}
|
||||||
func(*args, **kwargs)
|
|
||||||
|
res = func(*args, **kwargs)
|
||||||
|
if inspect.isawaitable(res):
|
||||||
|
asyncio.run(res)
|
||||||
|
|
||||||
return app.command(help=help)(wrapper)
|
return app.command(help=help)(wrapper)
|
||||||
|
|
||||||
|
|
|
@ -13,13 +13,11 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from urllib.parse import quote
|
|
||||||
from urllib.request import urlopen, Request
|
|
||||||
from urllib.error import HTTPError
|
|
||||||
import functools
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
|
from yarl import URL
|
||||||
|
import aiohttp
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from ..config import get_token
|
from ..config import get_token
|
||||||
|
@ -27,8 +25,6 @@ from ..cliq import cliq
|
||||||
|
|
||||||
history_count: int = 10
|
history_count: int = 10
|
||||||
|
|
||||||
enc = functools.partial(quote, safe="")
|
|
||||||
|
|
||||||
friendly_errors = {
|
friendly_errors = {
|
||||||
"server_not_found": "Registration target server not found.\n\n"
|
"server_not_found": "Registration target server not found.\n\n"
|
||||||
"To log in or register through maubot, you must add the server to the\n"
|
"To log in or register through maubot, you must add the server to the\n"
|
||||||
|
@ -37,6 +33,15 @@ friendly_errors = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
|
||||||
|
url = URL(server) / "_matrix/maubot/v1/client/auth/servers"
|
||||||
|
async with sess.get(url) as resp:
|
||||||
|
data = await resp.json()
|
||||||
|
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
|
||||||
|
for server in data.keys():
|
||||||
|
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
|
||||||
|
|
||||||
|
|
||||||
@cliq.command(help="Log into a Matrix account via the Maubot server")
|
@cliq.command(help="Log into a Matrix account via the Maubot server")
|
||||||
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
|
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
|
||||||
@cliq.option("-u", "--username", help="The username to log in with", required_unless="list")
|
@cliq.option("-u", "--username", help="The username to log in with", required_unless="list")
|
||||||
|
@ -46,42 +51,40 @@ friendly_errors = {
|
||||||
required=False, prompt=False)
|
required=False, prompt=False)
|
||||||
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
|
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
|
||||||
default=False)
|
default=False)
|
||||||
|
@click.option("-c", "--update-client", help="Instead of returning the access token, "
|
||||||
|
"create or update a client in maubot using it",
|
||||||
|
is_flag=True, default=False)
|
||||||
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
|
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
|
||||||
def auth(homeserver: str, username: str, password: str, server: str, register: bool, list: bool
|
@cliq.with_authenticated_http
|
||||||
) -> None:
|
async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
|
||||||
server, token = get_token(server)
|
list: bool, update_client: bool, sess: aiohttp.ClientSession) -> None:
|
||||||
if not token:
|
|
||||||
return
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
if list:
|
if list:
|
||||||
url = f"{server}/_matrix/maubot/v1/client/auth/servers"
|
await list_servers(server, sess)
|
||||||
with urlopen(Request(url, headers=headers)) as resp_data:
|
return
|
||||||
resp = json.load(resp_data)
|
|
||||||
print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}")
|
|
||||||
for server in resp.keys():
|
|
||||||
print(f"* {Fore.CYAN}{server}{Fore.RESET}")
|
|
||||||
return
|
|
||||||
endpoint = "register" if register else "login"
|
endpoint = "register" if register else "login"
|
||||||
headers["Content-Type"] = "application/json"
|
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
|
||||||
url = f"{server}/_matrix/maubot/v1/client/auth/{enc(homeserver)}/{endpoint}"
|
if update_client:
|
||||||
req = Request(url, headers=headers,
|
url = url.with_query({"update_client": "true"})
|
||||||
data=json.dumps({
|
req_data = {"username": username, "password": password}
|
||||||
"username": username,
|
|
||||||
"password": password,
|
async with sess.post(url, json=req_data) as resp:
|
||||||
}).encode("utf-8"))
|
if resp.status == 200:
|
||||||
try:
|
data = await resp.json()
|
||||||
with urlopen(req) as resp_data:
|
|
||||||
resp = json.load(resp_data)
|
|
||||||
action = "registered" if register else "logged in as"
|
action = "registered" if register else "logged in as"
|
||||||
print(f"{Fore.GREEN}Successfully {action} "
|
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
|
||||||
f"{Fore.CYAN}{resp['user_id']}{Fore.GREEN}.")
|
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
|
||||||
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{resp['access_token']}{Fore.RESET}")
|
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
|
||||||
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{resp['device_id']}{Fore.RESET}")
|
elif resp.status in (201, 202):
|
||||||
except HTTPError as e:
|
data = await resp.json()
|
||||||
try:
|
action = "created" if resp.status == 201 else "updated"
|
||||||
err_data = json.load(e)
|
print(f"{Fore.GREEN}Successfully {action} client for "
|
||||||
error = friendly_errors.get(err_data["errcode"], err_data["error"])
|
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
|
||||||
except (json.JSONDecodeError, KeyError):
|
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
|
||||||
error = str(e)
|
else:
|
||||||
action = "register" if register else "log in"
|
try:
|
||||||
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")
|
err_data = await resp.json()
|
||||||
|
error = friendly_errors.get(err_data["errcode"], err_data["error"])
|
||||||
|
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
|
||||||
|
error = await resp.text()
|
||||||
|
action = "register" if register else "log in"
|
||||||
|
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# maubot - A plugin-based Matrix bot system.
|
# maubot - A plugin-based Matrix bot system.
|
||||||
# Copyright (C) 2019 Tulir Asokan
|
# Copyright (C) 2021 Tulir Asokan
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
@ -86,10 +86,8 @@ class Client:
|
||||||
log=self.log, loop=self.loop, device_id=self.device_id,
|
log=self.log, loop=self.loop, device_id=self.device_id,
|
||||||
sync_store=SyncStoreProxy(self.db_instance),
|
sync_store=SyncStoreProxy(self.db_instance),
|
||||||
state_store=self.global_state_store)
|
state_store=self.global_state_store)
|
||||||
if OlmMachine and self.device_id and self.crypto_db:
|
if self.enable_crypto:
|
||||||
self.crypto_store = self._make_crypto_store()
|
self._prepare_crypto()
|
||||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
|
||||||
self.client.crypto = self.crypto
|
|
||||||
else:
|
else:
|
||||||
self.crypto_store = None
|
self.crypto_store = None
|
||||||
self.crypto = None
|
self.crypto = None
|
||||||
|
@ -102,10 +100,15 @@ class Client:
|
||||||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||||
|
|
||||||
def _make_crypto_store(self) -> 'CryptoStore':
|
@property
|
||||||
if self.crypto_db:
|
def enable_crypto(self) -> bool:
|
||||||
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
|
return bool(OlmMachine and self.device_id and self.crypto_db)
|
||||||
raise ValueError("Crypto database not configured")
|
|
||||||
|
def _prepare_crypto(self) -> None:
|
||||||
|
self.crypto_store = PgCryptoStore(account_id=self.id, pickle_key="mau.crypto",
|
||||||
|
db=self.crypto_db)
|
||||||
|
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||||
|
self.client.crypto = self.crypto
|
||||||
|
|
||||||
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
||||||
async def handler(data: Dict[str, Any]) -> None:
|
async def handler(data: Dict[str, Any]) -> None:
|
||||||
|
@ -121,6 +124,19 @@ class Client:
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to start")
|
self.log.exception("Failed to start")
|
||||||
|
|
||||||
|
async def _start_crypto(self) -> None:
|
||||||
|
self.log.debug("Enabling end-to-end encryption support")
|
||||||
|
await self.crypto_store.open()
|
||||||
|
crypto_device_id = await self.crypto_store.get_device_id()
|
||||||
|
if crypto_device_id and crypto_device_id != self.device_id:
|
||||||
|
self.log.warning("Mismatching device ID in crypto store and main database, "
|
||||||
|
"resetting encryption")
|
||||||
|
await self.crypto_store.delete()
|
||||||
|
crypto_device_id = None
|
||||||
|
await self.crypto.load()
|
||||||
|
if not crypto_device_id:
|
||||||
|
await self.crypto_store.put_device_id(self.device_id)
|
||||||
|
|
||||||
async def _start(self, try_n: Optional[int] = 0) -> None:
|
async def _start(self, try_n: Optional[int] = 0) -> None:
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
self.log.debug("Not starting disabled client")
|
self.log.debug("Not starting disabled client")
|
||||||
|
@ -129,7 +145,7 @@ class Client:
|
||||||
self.log.warning("Ignoring start() call to started client")
|
self.log.warning("Ignoring start() call to started client")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
user_id = await self.client.whoami()
|
whoami = await self.client.whoami()
|
||||||
except MatrixInvalidToken as e:
|
except MatrixInvalidToken as e:
|
||||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||||
self.db_instance.enabled = False
|
self.db_instance.enabled = False
|
||||||
|
@ -143,8 +159,13 @@ class Client:
|
||||||
f"retrying in {(try_n + 1) * 10}s: {e}")
|
f"retrying in {(try_n + 1) * 10}s: {e}")
|
||||||
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
||||||
return
|
return
|
||||||
if user_id != self.id:
|
if whoami.user_id != self.id:
|
||||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
|
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
|
||||||
|
self.db_instance.enabled = False
|
||||||
|
return
|
||||||
|
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
|
||||||
|
self.log.error(f"Device ID mismatch: expected {self.device_id}, "
|
||||||
|
f"but got {whoami.device_id}")
|
||||||
self.db_instance.enabled = False
|
self.db_instance.enabled = False
|
||||||
return
|
return
|
||||||
if not self.filter_id:
|
if not self.filter_id:
|
||||||
|
@ -167,15 +188,7 @@ class Client:
|
||||||
if self.avatar_url != "disable":
|
if self.avatar_url != "disable":
|
||||||
await self.client.set_avatar_url(self.avatar_url)
|
await self.client.set_avatar_url(self.avatar_url)
|
||||||
if self.crypto:
|
if self.crypto:
|
||||||
self.log.debug("Enabling end-to-end encryption support")
|
await self._start_crypto()
|
||||||
await self.crypto_store.open()
|
|
||||||
crypto_device_id = await self.crypto_store.get_device_id()
|
|
||||||
if crypto_device_id and crypto_device_id != self.device_id:
|
|
||||||
self.log.warning("Mismatching device ID in crypto store and main database. "
|
|
||||||
"Encryption may not work.")
|
|
||||||
await self.crypto.load()
|
|
||||||
if not crypto_device_id:
|
|
||||||
await self.crypto_store.put_device_id(self.device_id)
|
|
||||||
self.start_sync()
|
self.start_sync()
|
||||||
await self._update_remote_profile()
|
await self._update_remote_profile()
|
||||||
self.started = True
|
self.started = True
|
||||||
|
@ -285,23 +298,31 @@ class Client:
|
||||||
else:
|
else:
|
||||||
await self._update_remote_profile()
|
await self._update_remote_profile()
|
||||||
|
|
||||||
async def update_access_details(self, access_token: str, homeserver: str) -> None:
|
async def update_access_details(self, access_token: str, homeserver: str,
|
||||||
|
device_id: Optional[str] = None) -> None:
|
||||||
if not access_token and not homeserver:
|
if not access_token and not homeserver:
|
||||||
return
|
return
|
||||||
elif access_token == self.access_token and homeserver == self.homeserver:
|
elif access_token == self.access_token and homeserver == self.homeserver:
|
||||||
return
|
return
|
||||||
|
device_id = device_id or self.device_id
|
||||||
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
||||||
token=access_token or self.access_token, loop=self.loop,
|
token=access_token or self.access_token, loop=self.loop,
|
||||||
client_session=self.http_client, device_id=self.device_id,
|
device_id=device_id, client_session=self.http_client,
|
||||||
log=self.log, state_store=self.global_state_store)
|
log=self.log, state_store=self.global_state_store)
|
||||||
mxid = await new_client.whoami()
|
whoami = await new_client.whoami()
|
||||||
if mxid != self.id:
|
if whoami.user_id != self.id:
|
||||||
raise ValueError(f"MXID mismatch: {mxid}")
|
raise ValueError(f"MXID mismatch: {whoami.user_id}")
|
||||||
|
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||||
|
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
|
||||||
new_client.sync_store = SyncStoreProxy(self.db_instance)
|
new_client.sync_store = SyncStoreProxy(self.db_instance)
|
||||||
self.stop_sync()
|
self.stop_sync()
|
||||||
self.client = new_client
|
self.client = new_client
|
||||||
self.db_instance.homeserver = homeserver
|
self.db_instance.homeserver = homeserver
|
||||||
self.db_instance.access_token = access_token
|
self.db_instance.access_token = access_token
|
||||||
|
self.db_instance.device_id = device_id
|
||||||
|
if self.enable_crypto:
|
||||||
|
self._prepare_crypto()
|
||||||
|
await self._start_crypto()
|
||||||
self.start_sync()
|
self.start_sync()
|
||||||
|
|
||||||
async def _update_remote_profile(self) -> None:
|
async def _update_remote_profile(self) -> None:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# maubot - A plugin-based Matrix bot system.
|
# maubot - A plugin-based Matrix bot system.
|
||||||
# Copyright (C) 2019 Tulir Asokan
|
# Copyright (C) 2021 Tulir Asokan
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
@ -45,10 +45,11 @@ async def get_client(request: web.Request) -> web.Response:
|
||||||
async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
|
async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
|
||||||
homeserver = data.get("homeserver", None)
|
homeserver = data.get("homeserver", None)
|
||||||
access_token = data.get("access_token", None)
|
access_token = data.get("access_token", None)
|
||||||
|
device_id = data.get("device_id", None)
|
||||||
new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token,
|
new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token,
|
||||||
loop=Client.loop, client_session=Client.http_client)
|
loop=Client.loop, client_session=Client.http_client)
|
||||||
try:
|
try:
|
||||||
mxid = await new_client.whoami()
|
whoami = await new_client.whoami()
|
||||||
except MatrixInvalidToken:
|
except MatrixInvalidToken:
|
||||||
return resp.bad_client_access_token
|
return resp.bad_client_access_token
|
||||||
except MatrixRequestError:
|
except MatrixRequestError:
|
||||||
|
@ -56,27 +57,31 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
|
||||||
except MatrixConnectionError:
|
except MatrixConnectionError:
|
||||||
return resp.bad_client_connection_details
|
return resp.bad_client_connection_details
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
existing_client = Client.get(mxid, None)
|
existing_client = Client.get(whoami.user_id, None)
|
||||||
if existing_client is not None:
|
if existing_client is not None:
|
||||||
return resp.user_exists
|
return resp.user_exists
|
||||||
elif mxid != user_id:
|
elif whoami.user_id != user_id:
|
||||||
return resp.mxid_mismatch(mxid)
|
return resp.mxid_mismatch(whoami.user_id)
|
||||||
db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token,
|
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||||
|
return resp.device_id_mismatch(whoami.device_id)
|
||||||
|
db_instance = DBClient(id=whoami.user_id, homeserver=homeserver, access_token=access_token,
|
||||||
enabled=data.get("enabled", True), next_batch=SyncToken(""),
|
enabled=data.get("enabled", True), next_batch=SyncToken(""),
|
||||||
filter_id=FilterID(""), sync=data.get("sync", True),
|
filter_id=FilterID(""), sync=data.get("sync", True),
|
||||||
autojoin=data.get("autojoin", True), online=data.get("online", True),
|
autojoin=data.get("autojoin", True), online=data.get("online", True),
|
||||||
displayname=data.get("displayname", ""),
|
displayname=data.get("displayname", ""),
|
||||||
avatar_url=data.get("avatar_url", ""))
|
avatar_url=data.get("avatar_url", ""),
|
||||||
|
device_id=device_id)
|
||||||
client = Client(db_instance)
|
client = Client(db_instance)
|
||||||
client.db_instance.insert()
|
client.db_instance.insert()
|
||||||
await client.start()
|
await client.start()
|
||||||
return resp.created(client.to_dict())
|
return resp.created(client.to_dict())
|
||||||
|
|
||||||
|
|
||||||
async def _update_client(client: Client, data: dict) -> web.Response:
|
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
|
||||||
try:
|
try:
|
||||||
await client.update_access_details(data.get("access_token", None),
|
await client.update_access_details(data.get("access_token", None),
|
||||||
data.get("homeserver", None))
|
data.get("homeserver", None),
|
||||||
|
data.get("device_id", None))
|
||||||
except MatrixInvalidToken:
|
except MatrixInvalidToken:
|
||||||
return resp.bad_client_access_token
|
return resp.bad_client_access_token
|
||||||
except MatrixRequestError:
|
except MatrixRequestError:
|
||||||
|
@ -93,7 +98,16 @@ async def _update_client(client: Client, data: dict) -> web.Response:
|
||||||
client.autojoin = data.get("autojoin", client.autojoin)
|
client.autojoin = data.get("autojoin", client.autojoin)
|
||||||
client.online = data.get("online", client.online)
|
client.online = data.get("online", client.online)
|
||||||
client.sync = data.get("sync", client.sync)
|
client.sync = data.get("sync", client.sync)
|
||||||
return resp.updated(client.to_dict())
|
return resp.updated(client.to_dict(), is_login=is_login)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_or_update_client(user_id: UserID, data: dict, is_login: bool = False
|
||||||
|
) -> web.Response:
|
||||||
|
client = Client.get(user_id, None)
|
||||||
|
if not client:
|
||||||
|
return await _create_client(user_id, data)
|
||||||
|
else:
|
||||||
|
return await _update_client(client, data, is_login=is_login)
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/client/new")
|
@routes.post("/client/new")
|
||||||
|
@ -108,15 +122,11 @@ async def create_client(request: web.Request) -> web.Response:
|
||||||
@routes.put("/client/{id}")
|
@routes.put("/client/{id}")
|
||||||
async def update_client(request: web.Request) -> web.Response:
|
async def update_client(request: web.Request) -> web.Response:
|
||||||
user_id = request.match_info.get("id", None)
|
user_id = request.match_info.get("id", None)
|
||||||
client = Client.get(user_id, None)
|
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
return resp.body_not_json
|
return resp.body_not_json
|
||||||
if not client:
|
return await _create_or_update_client(user_id, data)
|
||||||
return await _create_client(user_id, data)
|
|
||||||
else:
|
|
||||||
return await _update_client(client, data)
|
|
||||||
|
|
||||||
|
|
||||||
@routes.delete("/client/{id}")
|
@routes.delete("/client/{id}")
|
||||||
|
|
|
@ -25,10 +25,11 @@ from aiohttp import web
|
||||||
from mautrix.api import SynapseAdminPath, Method
|
from mautrix.api import SynapseAdminPath, Method
|
||||||
from mautrix.errors import MatrixRequestError
|
from mautrix.errors import MatrixRequestError
|
||||||
from mautrix.client import ClientAPI
|
from mautrix.client import ClientAPI
|
||||||
from mautrix.types import LoginType
|
from mautrix.types import LoginType, LoginResponse
|
||||||
|
|
||||||
from .base import routes, get_config, get_loop
|
from .base import routes, get_config, get_loop
|
||||||
from .responses import resp
|
from .responses import resp
|
||||||
|
from .client import _create_or_update_client, _create_client
|
||||||
|
|
||||||
|
|
||||||
def known_homeservers() -> Dict[str, Dict[str, str]]:
|
def known_homeservers() -> Dict[str, Dict[str, str]]:
|
||||||
|
@ -46,6 +47,7 @@ class AuthRequestInfo(NamedTuple):
|
||||||
username: str
|
username: str
|
||||||
password: str
|
password: str
|
||||||
user_type: str
|
user_type: str
|
||||||
|
update_client: bool
|
||||||
|
|
||||||
|
|
||||||
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
|
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
|
||||||
|
@ -70,15 +72,16 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
|
||||||
secret = server.get("secret")
|
secret = server.get("secret")
|
||||||
api = ClientAPI(base_url=base_url, loop=get_loop())
|
api = ClientAPI(base_url=base_url, loop=get_loop())
|
||||||
user_type = body.get("user_type", "bot")
|
user_type = body.get("user_type", "bot")
|
||||||
return AuthRequestInfo(api, secret, username, password, user_type), None
|
update_client = request.query.get("update_client", "").lower() in ("1", "true", "yes")
|
||||||
|
return AuthRequestInfo(api, secret, username, password, user_type, update_client), None
|
||||||
|
|
||||||
|
|
||||||
def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool = False,
|
def generate_mac(secret: str, nonce: str, username: str, password: str, admin: bool = False,
|
||||||
user_type: str = None) -> str:
|
user_type: str = None) -> str:
|
||||||
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
|
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
|
||||||
mac.update(nonce.encode("utf-8"))
|
mac.update(nonce.encode("utf-8"))
|
||||||
mac.update(b"\x00")
|
mac.update(b"\x00")
|
||||||
mac.update(user.encode("utf-8"))
|
mac.update(username.encode("utf-8"))
|
||||||
mac.update(b"\x00")
|
mac.update(b"\x00")
|
||||||
mac.update(password.encode("utf-8"))
|
mac.update(password.encode("utf-8"))
|
||||||
mac.update(b"\x00")
|
mac.update(b"\x00")
|
||||||
|
@ -94,28 +97,34 @@ async def register(request: web.Request) -> web.Response:
|
||||||
info, err = await read_client_auth_request(request)
|
info, err = await read_client_auth_request(request)
|
||||||
if err is not None:
|
if err is not None:
|
||||||
return err
|
return err
|
||||||
client: ClientAPI
|
if not info.secret:
|
||||||
client, secret, username, password, user_type = info
|
|
||||||
if not secret:
|
|
||||||
return resp.registration_secret_not_found
|
return resp.registration_secret_not_found
|
||||||
path = SynapseAdminPath.v1.register
|
path = SynapseAdminPath.v1.register
|
||||||
res = await client.api.request(Method.GET, path)
|
res = await info.client.api.request(Method.GET, path)
|
||||||
content = {
|
content = {
|
||||||
"nonce": res["nonce"],
|
"nonce": res["nonce"],
|
||||||
"username": username,
|
"username": info.username,
|
||||||
"password": password,
|
"password": info.password,
|
||||||
"admin": False,
|
"admin": False,
|
||||||
"mac": generate_mac(secret, res["nonce"], username, password, user_type=user_type),
|
"user_type": info.user_type,
|
||||||
"user_type": user_type,
|
|
||||||
}
|
}
|
||||||
|
content["mac"] = generate_mac(**content, secret=info.secret)
|
||||||
try:
|
try:
|
||||||
return web.json_response(await client.api.request(Method.POST, path, content=content))
|
raw_res = await info.client.api.request(Method.POST, path, content=content)
|
||||||
except MatrixRequestError as e:
|
except MatrixRequestError as e:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"errcode": e.errcode,
|
"errcode": e.errcode,
|
||||||
"error": e.message,
|
"error": e.message,
|
||||||
"http_status": e.http_status,
|
"http_status": e.http_status,
|
||||||
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
|
login_res = LoginResponse.deserialize(raw_res)
|
||||||
|
if info.update_client:
|
||||||
|
return await _create_client(login_res.user_id, {
|
||||||
|
"homeserver": str(info.client.api.base_url),
|
||||||
|
"access_token": login_res.access_token,
|
||||||
|
"device_id": login_res.device_id,
|
||||||
|
})
|
||||||
|
return web.json_response(login_res.serialize())
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/client/auth/{server}/login")
|
@routes.post("/client/auth/{server}/login")
|
||||||
|
@ -129,9 +138,15 @@ async def login(request: web.Request) -> web.Response:
|
||||||
res = await client.login(identifier=info.username, login_type=LoginType.PASSWORD,
|
res = await client.login(identifier=info.username, login_type=LoginType.PASSWORD,
|
||||||
password=info.password, device_id=f"maubot_{device_id}",
|
password=info.password, device_id=f"maubot_{device_id}",
|
||||||
initial_device_display_name="Maubot", store_access_token=False)
|
initial_device_display_name="Maubot", store_access_token=False)
|
||||||
return web.json_response(res.serialize())
|
|
||||||
except MatrixRequestError as e:
|
except MatrixRequestError as e:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"errcode": e.errcode,
|
"errcode": e.errcode,
|
||||||
"error": e.message,
|
"error": e.message,
|
||||||
}, status=e.http_status)
|
}, status=e.http_status)
|
||||||
|
if info.update_client:
|
||||||
|
return await _create_or_update_client(res.user_id, {
|
||||||
|
"homeserver": str(client.api.base_url),
|
||||||
|
"access_token": res.access_token,
|
||||||
|
"device_id": res.device_id,
|
||||||
|
}, is_login=True)
|
||||||
|
return web.json_response(res.serialize())
|
||||||
|
|
|
@ -69,6 +69,13 @@ class _Response:
|
||||||
"errcode": "mxid_mismatch",
|
"errcode": "mxid_mismatch",
|
||||||
}, status=HTTPStatus.BAD_REQUEST)
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
def device_id_mismatch(self, found: str) -> web.Response:
|
||||||
|
return web.json_response({
|
||||||
|
"error": "The Matrix device ID of the client and the device ID of the access token "
|
||||||
|
f"don't match. Access token is for device {found}",
|
||||||
|
"errcode": "mxid_mismatch",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pid_mismatch(self) -> web.Response:
|
def pid_mismatch(self) -> web.Response:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
|
@ -294,8 +301,9 @@ class _Response:
|
||||||
def found(data: dict) -> web.Response:
|
def found(data: dict) -> web.Response:
|
||||||
return web.json_response(data, status=HTTPStatus.OK)
|
return web.json_response(data, status=HTTPStatus.OK)
|
||||||
|
|
||||||
def updated(self, data: dict) -> web.Response:
|
@staticmethod
|
||||||
return self.found(data)
|
def updated(data: dict, is_login: bool = False) -> web.Response:
|
||||||
|
return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK)
|
||||||
|
|
||||||
def logged_in(self, token: str) -> web.Response:
|
def logged_in(self, token: str) -> web.Response:
|
||||||
return self.found({
|
return self.found({
|
||||||
|
|
|
@ -366,7 +366,7 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/MatrixClient'
|
$ref: '#/components/schemas/MatrixClient'
|
||||||
responses:
|
responses:
|
||||||
200:
|
202:
|
||||||
description: Client updated
|
description: Client updated
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
|
@ -454,6 +454,12 @@ paths:
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
|
- name: update_client
|
||||||
|
in: query
|
||||||
|
description: Should maubot store the access details in a Client instead of returning them?
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: boolean
|
||||||
post:
|
post:
|
||||||
operationId: client_auth_register
|
operationId: client_auth_register
|
||||||
summary: |
|
summary: |
|
||||||
|
@ -475,18 +481,29 @@ paths:
|
||||||
properties:
|
properties:
|
||||||
access_token:
|
access_token:
|
||||||
type: string
|
type: string
|
||||||
example: token_here
|
example: syt_123_456_789
|
||||||
user_id:
|
user_id:
|
||||||
type: string
|
type: string
|
||||||
example: '@putkiteippi:maunium.net'
|
example: '@putkiteippi:maunium.net'
|
||||||
home_server:
|
|
||||||
type: string
|
|
||||||
example: maunium.net
|
|
||||||
device_id:
|
device_id:
|
||||||
type: string
|
type: string
|
||||||
example: device_id_here
|
example: maubot_F00BAR12
|
||||||
|
201:
|
||||||
|
description: Client created (when update_client is true)
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/MatrixClient'
|
||||||
401:
|
401:
|
||||||
$ref: '#/components/responses/Unauthorized'
|
$ref: '#/components/responses/Unauthorized'
|
||||||
|
409:
|
||||||
|
description: |
|
||||||
|
There is already a client with the user ID of that token.
|
||||||
|
This should usually not happen, because the user ID was just created.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Error'
|
||||||
500:
|
500:
|
||||||
$ref: '#/components/responses/MatrixServerError'
|
$ref: '#/components/responses/MatrixServerError'
|
||||||
'/client/auth/{server}/login':
|
'/client/auth/{server}/login':
|
||||||
|
@ -497,6 +514,12 @@ paths:
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
|
- name: update_client
|
||||||
|
in: query
|
||||||
|
description: Should maubot store the access details in a Client instead of returning them?
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: boolean
|
||||||
post:
|
post:
|
||||||
operationId: client_auth_login
|
operationId: client_auth_login
|
||||||
summary: Log in to the given Matrix server via the maubot server
|
summary: Log in to the given Matrix server via the maubot server
|
||||||
|
@ -519,10 +542,22 @@ paths:
|
||||||
example: '@putkiteippi:maunium.net'
|
example: '@putkiteippi:maunium.net'
|
||||||
access_token:
|
access_token:
|
||||||
type: string
|
type: string
|
||||||
example: token_here
|
example: syt_123_456_789
|
||||||
device_id:
|
device_id:
|
||||||
type: string
|
type: string
|
||||||
example: device_id_here
|
example: maubot_F00BAR12
|
||||||
|
201:
|
||||||
|
description: Client created (when update_client is true)
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/MatrixClient'
|
||||||
|
202:
|
||||||
|
description: Client updated (when update_client is true)
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/MatrixClient'
|
||||||
401:
|
401:
|
||||||
$ref: '#/components/responses/Unauthorized'
|
$ref: '#/components/responses/Unauthorized'
|
||||||
500:
|
500:
|
||||||
|
@ -641,6 +676,9 @@ components:
|
||||||
access_token:
|
access_token:
|
||||||
type: string
|
type: string
|
||||||
description: The Matrix access token for this client.
|
description: The Matrix access token for this client.
|
||||||
|
device_id:
|
||||||
|
type: string
|
||||||
|
description: The Matrix device ID corresponding to the access token.
|
||||||
enabled:
|
enabled:
|
||||||
type: boolean
|
type: boolean
|
||||||
example: true
|
example: true
|
||||||
|
|
|
@ -144,13 +144,13 @@ async def main():
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
whoami_user_id = await client.whoami()
|
whoami = await client.whoami()
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("Failed to connect to homeserver, retrying in 10 seconds...")
|
log.exception("Failed to connect to homeserver, retrying in 10 seconds...")
|
||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
continue
|
continue
|
||||||
if whoami_user_id != user_id:
|
if whoami.user_id != user_id:
|
||||||
log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami_user_id}")
|
log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami.user_id}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue