Add SSO support to mbc auth
This commit is contained in:
parent
f2bae18c7a
commit
ca7a980081
7 changed files with 177 additions and 40 deletions
maubot
|
@ -73,6 +73,11 @@ def command(help: str) -> Callable[[Callable], Callable]:
|
|||
required_unless = questions[key].pop("required_unless")
|
||||
if isinstance(required_unless, str) and kwargs[required_unless]:
|
||||
questions.pop(key)
|
||||
elif isinstance(required_unless, list):
|
||||
for v in required_unless:
|
||||
if kwargs[v]:
|
||||
questions.pop(key)
|
||||
break
|
||||
elif isinstance(required_unless, dict):
|
||||
for k, v in required_unless.items():
|
||||
if kwargs.get(v, object()) == v:
|
||||
|
@ -118,7 +123,7 @@ def option(short: str, long: str, message: str = None, help: str = None,
|
|||
click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None,
|
||||
validator: Type[Validator] = None, required: bool = False,
|
||||
default: Union[str, bool, None] = None, is_flag: bool = False, prompt: bool = True,
|
||||
required_unless: str = None) -> Callable[[Callable], Callable]:
|
||||
required_unless: Union[str, list, dict] = None) -> Callable[[Callable], Callable]:
|
||||
if not message:
|
||||
message = long[2].upper() + long[3:]
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
import webbrowser
|
||||
import json
|
||||
|
||||
from colorama import Fore
|
||||
|
@ -28,7 +29,9 @@ friendly_errors = {
|
|||
"server_not_found": "Registration target server not found.\n\n"
|
||||
"To log in or register through maubot, you must add the server to the\n"
|
||||
"homeservers section in the config. If you only want to log in,\n"
|
||||
"leave the `secret` field empty."
|
||||
"leave the `secret` field empty.",
|
||||
"registration_no_sso": "The register operation is only for registering with a password.\n\n"
|
||||
"To register with SSO, simply leave out the --register flag.",
|
||||
}
|
||||
|
||||
|
||||
|
@ -43,9 +46,10 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
|
|||
|
||||
@cliq.command(help="Log into a Matrix account via the Maubot server")
|
||||
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
|
||||
@cliq.option("-u", "--username", help="The username to log in with", required_unless="list")
|
||||
@cliq.option("-u", "--username", help="The username to log in with",
|
||||
required_unless=["list", "sso"])
|
||||
@cliq.option("-p", "--password", help="The password to log in with", inq_type="password",
|
||||
required_unless="list")
|
||||
required_unless=["list", "sso"])
|
||||
@cliq.option("-s", "--server", help="The maubot instance to log in through", default="",
|
||||
required=False, prompt=False)
|
||||
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
|
||||
|
@ -54,25 +58,52 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
|
|||
"create or update a client in maubot using it",
|
||||
is_flag=True, default=False)
|
||||
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
|
||||
@click.option("-o", "--sso", help="Use single sign-on instead of password login",
|
||||
is_flag=True, default=False)
|
||||
@click.option("-n", "--device-name", help="The initial e2ee device displayname (only for login)",
|
||||
default="Maubot", required=False)
|
||||
@cliq.with_authenticated_http
|
||||
async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
|
||||
list: bool, update_client: bool, device_name: str, sess: aiohttp.ClientSession
|
||||
) -> None:
|
||||
list: bool, update_client: bool, device_name: str, sso: bool,
|
||||
sess: aiohttp.ClientSession) -> None:
|
||||
if list:
|
||||
await list_servers(server, sess)
|
||||
return
|
||||
endpoint = "register" if register else "login"
|
||||
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
|
||||
if update_client:
|
||||
url = url.with_query({"update_client": "true"})
|
||||
url = url.update_query({"update_client": "true"})
|
||||
if sso:
|
||||
url = url.update_query({"sso": "true"})
|
||||
req_data = {"device_name": device_name}
|
||||
else:
|
||||
req_data = {"username": username, "password": password, "device_name": device_name}
|
||||
|
||||
action = "registered" if register else "logged in as"
|
||||
async with sess.post(url, json=req_data) as resp:
|
||||
if not 200 <= resp.status < 300:
|
||||
await print_error(resp, action)
|
||||
elif sso:
|
||||
await wait_sso(resp, sess, server, homeserver)
|
||||
else:
|
||||
await print_response(resp, action)
|
||||
|
||||
|
||||
async def wait_sso(resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession,
|
||||
server: str, homeserver: str) -> None:
|
||||
data = await resp.json()
|
||||
sso_url, reg_id = data["sso_url"], data["id"]
|
||||
print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}")
|
||||
webbrowser.open(sso_url, autoraise=True)
|
||||
print(f"{Fore.GREEN}Waiting for login token...{Fore.RESET}")
|
||||
wait_url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / "sso" / reg_id / "wait"
|
||||
async with sess.post(wait_url, json={}) as resp:
|
||||
await print_response(resp, "logged in as")
|
||||
|
||||
|
||||
async def print_response(resp: aiohttp.ClientResponse, action: str) -> None:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
action = "registered" if register else "logged in as"
|
||||
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
|
||||
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
|
||||
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
|
||||
|
@ -83,10 +114,13 @@ async def auth(homeserver: str, username: str, password: str, server: str, regis
|
|||
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
|
||||
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
|
||||
else:
|
||||
await print_error(resp, action)
|
||||
|
||||
|
||||
async def print_error(resp: aiohttp.ClientResponse, action: str) -> None:
|
||||
try:
|
||||
err_data = await resp.json()
|
||||
error = friendly_errors.get(err_data["errcode"], err_data["error"])
|
||||
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
|
||||
error = await resp.text()
|
||||
action = "register" if register else "log in"
|
||||
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")
|
||||
|
|
|
@ -241,7 +241,8 @@ class Client:
|
|||
"homeserver": self.homeserver,
|
||||
"access_token": self.access_token,
|
||||
"device_id": self.device_id,
|
||||
"fingerprint": self.crypto.account.fingerprint if self.crypto else None,
|
||||
"fingerprint": (self.crypto.account.fingerprint if self.crypto and self.crypto.account
|
||||
else None),
|
||||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"sync": self.sync,
|
||||
|
|
|
@ -68,8 +68,8 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
|
|||
enabled=data.get("enabled", True), next_batch=SyncToken(""),
|
||||
filter_id=FilterID(""), sync=data.get("sync", True),
|
||||
autojoin=data.get("autojoin", True), online=data.get("online", True),
|
||||
displayname=data.get("displayname", ""),
|
||||
avatar_url=data.get("avatar_url", ""),
|
||||
displayname=data.get("displayname", "disable"),
|
||||
avatar_url=data.get("avatar_url", "disable"),
|
||||
device_id=device_id)
|
||||
client = Client(db_instance)
|
||||
client.db_instance.insert()
|
||||
|
|
|
@ -17,12 +17,15 @@ from typing import Dict, Tuple, NamedTuple, Optional
|
|||
from json import JSONDecodeError
|
||||
from http import HTTPStatus
|
||||
import hashlib
|
||||
import asyncio
|
||||
import random
|
||||
import string
|
||||
import hmac
|
||||
|
||||
from aiohttp import web
|
||||
from mautrix.api import SynapseAdminPath, Method
|
||||
from yarl import URL
|
||||
|
||||
from mautrix.api import SynapseAdminPath, Method, Path
|
||||
from mautrix.errors import MatrixRequestError
|
||||
from mautrix.client import ClientAPI
|
||||
from mautrix.types import LoginType, LoginResponse
|
||||
|
@ -42,6 +45,7 @@ async def get_known_servers(_: web.Request) -> web.Response:
|
|||
|
||||
|
||||
class AuthRequestInfo(NamedTuple):
|
||||
server_name: str
|
||||
client: ClientAPI
|
||||
secret: str
|
||||
username: str
|
||||
|
@ -49,6 +53,10 @@ class AuthRequestInfo(NamedTuple):
|
|||
user_type: str
|
||||
device_name: str
|
||||
update_client: bool
|
||||
sso: bool
|
||||
|
||||
|
||||
truthy_strings = ("1", "true", "yes")
|
||||
|
||||
|
||||
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
|
||||
|
@ -61,23 +69,28 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
|
|||
body = await request.json()
|
||||
except JSONDecodeError:
|
||||
return None, resp.body_not_json
|
||||
sso = request.query.get("sso", "").lower() in truthy_strings
|
||||
try:
|
||||
username = body["username"]
|
||||
password = body["password"]
|
||||
except KeyError:
|
||||
if not sso:
|
||||
return None, resp.username_or_password_missing
|
||||
username = password = None
|
||||
try:
|
||||
base_url = server["url"]
|
||||
except KeyError:
|
||||
return None, resp.invalid_server
|
||||
return AuthRequestInfo(
|
||||
server_name=server_name,
|
||||
client=ClientAPI(base_url=base_url, loop=get_loop()),
|
||||
secret=server.get("secret"),
|
||||
username=username,
|
||||
password=password,
|
||||
user_type=body.get("user_type", "bot"),
|
||||
device_name=body.get("device_name", "Maubot"),
|
||||
update_client=request.query.get("update_client", "").lower() in ("1", "true", "yes"),
|
||||
update_client=request.query.get("update_client", "").lower() in truthy_strings,
|
||||
sso=sso,
|
||||
), None
|
||||
|
||||
|
||||
|
@ -102,7 +115,9 @@ async def register(request: web.Request) -> web.Response:
|
|||
req, err = await read_client_auth_request(request)
|
||||
if err is not None:
|
||||
return err
|
||||
if not req.secret:
|
||||
if req.sso:
|
||||
return resp.registration_no_sso
|
||||
elif not req.secret:
|
||||
return resp.registration_secret_not_found
|
||||
path = SynapseAdminPath.v1.register
|
||||
res = await req.client.api.request(Method.GET, path)
|
||||
|
@ -137,10 +152,38 @@ async def login(request: web.Request) -> web.Response:
|
|||
req, err = await read_client_auth_request(request)
|
||||
if err is not None:
|
||||
return err
|
||||
device_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
|
||||
if req.sso:
|
||||
return await _do_sso(req)
|
||||
else:
|
||||
return await _do_login(req)
|
||||
|
||||
|
||||
async def _do_sso(req: AuthRequestInfo) -> web.Response:
|
||||
flows = await req.client.get_login_flows()
|
||||
if not flows.supports_type(LoginType.SSO):
|
||||
return resp.sso_not_supported
|
||||
waiter_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=16))
|
||||
cfg = get_config()
|
||||
public_url = (URL(cfg["server.public_url"]) / cfg["server.base_path"].lstrip("/")
|
||||
/ "client/auth_external_sso/complete" / waiter_id)
|
||||
sso_url = (req.client.api.base_url
|
||||
.with_path(str(Path.login.sso.redirect))
|
||||
.with_query({"redirectUrl": str(public_url)}))
|
||||
sso_waiters[waiter_id] = req, get_loop().create_future()
|
||||
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
|
||||
|
||||
|
||||
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
|
||||
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
|
||||
device_id = f"maubot_{device_id}"
|
||||
try:
|
||||
if req.sso:
|
||||
res = await req.client.login(token=login_token, login_type=LoginType.TOKEN,
|
||||
device_id=device_id, store_access_token=False,
|
||||
initial_device_display_name=req.device_name)
|
||||
else:
|
||||
res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD,
|
||||
password=req.password, device_id=f"maubot_{device_id}",
|
||||
password=req.password, device_id=device_id,
|
||||
initial_device_display_name=req.device_name,
|
||||
store_access_token=False)
|
||||
except MatrixRequestError as e:
|
||||
|
@ -155,3 +198,38 @@ async def login(request: web.Request) -> web.Response:
|
|||
"device_id": res.device_id,
|
||||
}, is_login=True)
|
||||
return web.json_response(res.serialize())
|
||||
|
||||
|
||||
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}
|
||||
|
||||
|
||||
@routes.post("/client/auth/{server}/sso/{id}/wait")
|
||||
async def wait_sso(request: web.Request) -> web.Response:
|
||||
waiter_id = request.match_info["id"]
|
||||
req, fut = sso_waiters[waiter_id]
|
||||
try:
|
||||
login_token = await fut
|
||||
finally:
|
||||
sso_waiters.pop(waiter_id, None)
|
||||
return await _do_login(req, login_token)
|
||||
|
||||
|
||||
@routes.get("/client/auth_external_sso/complete/{id}")
|
||||
async def complete_sso(request: web.Request) -> web.Response:
|
||||
try:
|
||||
_, fut = sso_waiters[request.match_info["id"]]
|
||||
except KeyError:
|
||||
return web.Response(status=404, text="Invalid session ID\n")
|
||||
if fut.cancelled():
|
||||
return web.Response(status=200, text="The login was cancelled from the Maubot client\n")
|
||||
elif fut.done():
|
||||
return web.Response(status=200, text="The login token was already received\n")
|
||||
try:
|
||||
fut.set_result(request.query["loginToken"])
|
||||
except KeyError:
|
||||
return web.Response(status=400, text="Missing loginToken query parameter\n")
|
||||
except asyncio.InvalidStateError:
|
||||
return web.Response(status=500, text="Invalid state\n")
|
||||
return web.Response(status=200,
|
||||
text="Login token received, please return to your Maubot client. "
|
||||
"This tab can be closed.\n")
|
||||
|
|
|
@ -29,7 +29,12 @@ log = logging.getLogger("maubot.server")
|
|||
@web.middleware
|
||||
async def auth(request: web.Request, handler: Handler) -> web.Response:
|
||||
subpath = request.path[len(get_config()["server.base_path"]):]
|
||||
if subpath.startswith("/auth/") or subpath == "/features" or subpath == "/logs":
|
||||
if (
|
||||
subpath.startswith("/auth/")
|
||||
or subpath.startswith("/client/auth_external_sso/complete/")
|
||||
or subpath == "/features"
|
||||
or subpath == "/logs"
|
||||
):
|
||||
return await handler(request)
|
||||
err = check_token(request)
|
||||
if err is not None:
|
||||
|
|
|
@ -194,6 +194,20 @@ class _Response:
|
|||
"errcode": "registration_secret_not_found",
|
||||
}, status=HTTPStatus.NOT_FOUND)
|
||||
|
||||
@property
|
||||
def registration_no_sso(self) -> web.Response:
|
||||
return web.json_response({
|
||||
"error": "The register operation is only for registering with a password",
|
||||
"errcode": "registration_no_sso",
|
||||
}, status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
@property
|
||||
def sso_not_supported(self) -> web.Response:
|
||||
return web.json_response({
|
||||
"error": "That server does not seem to support single sign-on",
|
||||
"errcode": "sso_not_supported",
|
||||
}, status=HTTPStatus.FORBIDDEN)
|
||||
|
||||
@property
|
||||
def plugin_has_no_database(self) -> web.Response:
|
||||
return web.json_response({
|
||||
|
|
Loading…
Reference in a new issue