Add SSO support to mbc auth
This commit is contained in:
parent
f2bae18c7a
commit
ca7a980081
7 changed files with 177 additions and 40 deletions
|
@ -73,6 +73,11 @@ def command(help: str) -> Callable[[Callable], Callable]:
|
||||||
required_unless = questions[key].pop("required_unless")
|
required_unless = questions[key].pop("required_unless")
|
||||||
if isinstance(required_unless, str) and kwargs[required_unless]:
|
if isinstance(required_unless, str) and kwargs[required_unless]:
|
||||||
questions.pop(key)
|
questions.pop(key)
|
||||||
|
elif isinstance(required_unless, list):
|
||||||
|
for v in required_unless:
|
||||||
|
if kwargs[v]:
|
||||||
|
questions.pop(key)
|
||||||
|
break
|
||||||
elif isinstance(required_unless, dict):
|
elif isinstance(required_unless, dict):
|
||||||
for k, v in required_unless.items():
|
for k, v in required_unless.items():
|
||||||
if kwargs.get(v, object()) == v:
|
if kwargs.get(v, object()) == v:
|
||||||
|
@ -118,7 +123,7 @@ def option(short: str, long: str, message: str = None, help: str = None,
|
||||||
click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None,
|
click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None,
|
||||||
validator: Type[Validator] = None, required: bool = False,
|
validator: Type[Validator] = None, required: bool = False,
|
||||||
default: Union[str, bool, None] = None, is_flag: bool = False, prompt: bool = True,
|
default: Union[str, bool, None] = None, is_flag: bool = False, prompt: bool = True,
|
||||||
required_unless: str = None) -> Callable[[Callable], Callable]:
|
required_unless: Union[str, list, dict] = None) -> Callable[[Callable], Callable]:
|
||||||
if not message:
|
if not message:
|
||||||
message = long[2].upper() + long[3:]
|
message = long[2].upper() + long[3:]
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
import webbrowser
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
|
@ -28,7 +29,9 @@ friendly_errors = {
|
||||||
"server_not_found": "Registration target server not found.\n\n"
|
"server_not_found": "Registration target server not found.\n\n"
|
||||||
"To log in or register through maubot, you must add the server to the\n"
|
"To log in or register through maubot, you must add the server to the\n"
|
||||||
"homeservers section in the config. If you only want to log in,\n"
|
"homeservers section in the config. If you only want to log in,\n"
|
||||||
"leave the `secret` field empty."
|
"leave the `secret` field empty.",
|
||||||
|
"registration_no_sso": "The register operation is only for registering with a password.\n\n"
|
||||||
|
"To register with SSO, simply leave out the --register flag.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,9 +46,10 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
|
||||||
|
|
||||||
@cliq.command(help="Log into a Matrix account via the Maubot server")
|
@cliq.command(help="Log into a Matrix account via the Maubot server")
|
||||||
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
|
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
|
||||||
@cliq.option("-u", "--username", help="The username to log in with", required_unless="list")
|
@cliq.option("-u", "--username", help="The username to log in with",
|
||||||
|
required_unless=["list", "sso"])
|
||||||
@cliq.option("-p", "--password", help="The password to log in with", inq_type="password",
|
@cliq.option("-p", "--password", help="The password to log in with", inq_type="password",
|
||||||
required_unless="list")
|
required_unless=["list", "sso"])
|
||||||
@cliq.option("-s", "--server", help="The maubot instance to log in through", default="",
|
@cliq.option("-s", "--server", help="The maubot instance to log in through", default="",
|
||||||
required=False, prompt=False)
|
required=False, prompt=False)
|
||||||
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
|
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
|
||||||
|
@ -54,25 +58,52 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
|
||||||
"create or update a client in maubot using it",
|
"create or update a client in maubot using it",
|
||||||
is_flag=True, default=False)
|
is_flag=True, default=False)
|
||||||
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
|
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
|
||||||
|
@click.option("-o", "--sso", help="Use single sign-on instead of password login",
|
||||||
|
is_flag=True, default=False)
|
||||||
@click.option("-n", "--device-name", help="The initial e2ee device displayname (only for login)",
|
@click.option("-n", "--device-name", help="The initial e2ee device displayname (only for login)",
|
||||||
default="Maubot", required=False)
|
default="Maubot", required=False)
|
||||||
@cliq.with_authenticated_http
|
@cliq.with_authenticated_http
|
||||||
async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
|
async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
|
||||||
list: bool, update_client: bool, device_name: str, sess: aiohttp.ClientSession
|
list: bool, update_client: bool, device_name: str, sso: bool,
|
||||||
) -> None:
|
sess: aiohttp.ClientSession) -> None:
|
||||||
if list:
|
if list:
|
||||||
await list_servers(server, sess)
|
await list_servers(server, sess)
|
||||||
return
|
return
|
||||||
endpoint = "register" if register else "login"
|
endpoint = "register" if register else "login"
|
||||||
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
|
url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint
|
||||||
if update_client:
|
if update_client:
|
||||||
url = url.with_query({"update_client": "true"})
|
url = url.update_query({"update_client": "true"})
|
||||||
|
if sso:
|
||||||
|
url = url.update_query({"sso": "true"})
|
||||||
|
req_data = {"device_name": device_name}
|
||||||
|
else:
|
||||||
req_data = {"username": username, "password": password, "device_name": device_name}
|
req_data = {"username": username, "password": password, "device_name": device_name}
|
||||||
|
|
||||||
|
action = "registered" if register else "logged in as"
|
||||||
async with sess.post(url, json=req_data) as resp:
|
async with sess.post(url, json=req_data) as resp:
|
||||||
|
if not 200 <= resp.status < 300:
|
||||||
|
await print_error(resp, action)
|
||||||
|
elif sso:
|
||||||
|
await wait_sso(resp, sess, server, homeserver)
|
||||||
|
else:
|
||||||
|
await print_response(resp, action)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_sso(resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession,
|
||||||
|
server: str, homeserver: str) -> None:
|
||||||
|
data = await resp.json()
|
||||||
|
sso_url, reg_id = data["sso_url"], data["id"]
|
||||||
|
print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}")
|
||||||
|
webbrowser.open(sso_url, autoraise=True)
|
||||||
|
print(f"{Fore.GREEN}Waiting for login token...{Fore.RESET}")
|
||||||
|
wait_url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / "sso" / reg_id / "wait"
|
||||||
|
async with sess.post(wait_url, json={}) as resp:
|
||||||
|
await print_response(resp, "logged in as")
|
||||||
|
|
||||||
|
|
||||||
|
async def print_response(resp: aiohttp.ClientResponse, action: str) -> None:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
data = await resp.json()
|
data = await resp.json()
|
||||||
action = "registered" if register else "logged in as"
|
|
||||||
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
|
print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.")
|
||||||
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
|
print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}")
|
||||||
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
|
print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}")
|
||||||
|
@ -83,10 +114,13 @@ async def auth(homeserver: str, username: str, password: str, server: str, regis
|
||||||
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
|
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
|
||||||
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
|
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
|
||||||
else:
|
else:
|
||||||
|
await print_error(resp, action)
|
||||||
|
|
||||||
|
|
||||||
|
async def print_error(resp: aiohttp.ClientResponse, action: str) -> None:
|
||||||
try:
|
try:
|
||||||
err_data = await resp.json()
|
err_data = await resp.json()
|
||||||
error = friendly_errors.get(err_data["errcode"], err_data["error"])
|
error = friendly_errors.get(err_data["errcode"], err_data["error"])
|
||||||
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
|
except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError):
|
||||||
error = await resp.text()
|
error = await resp.text()
|
||||||
action = "register" if register else "log in"
|
|
||||||
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")
|
print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}")
|
||||||
|
|
|
@ -241,7 +241,8 @@ class Client:
|
||||||
"homeserver": self.homeserver,
|
"homeserver": self.homeserver,
|
||||||
"access_token": self.access_token,
|
"access_token": self.access_token,
|
||||||
"device_id": self.device_id,
|
"device_id": self.device_id,
|
||||||
"fingerprint": self.crypto.account.fingerprint if self.crypto else None,
|
"fingerprint": (self.crypto.account.fingerprint if self.crypto and self.crypto.account
|
||||||
|
else None),
|
||||||
"enabled": self.enabled,
|
"enabled": self.enabled,
|
||||||
"started": self.started,
|
"started": self.started,
|
||||||
"sync": self.sync,
|
"sync": self.sync,
|
||||||
|
|
|
@ -68,8 +68,8 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
|
||||||
enabled=data.get("enabled", True), next_batch=SyncToken(""),
|
enabled=data.get("enabled", True), next_batch=SyncToken(""),
|
||||||
filter_id=FilterID(""), sync=data.get("sync", True),
|
filter_id=FilterID(""), sync=data.get("sync", True),
|
||||||
autojoin=data.get("autojoin", True), online=data.get("online", True),
|
autojoin=data.get("autojoin", True), online=data.get("online", True),
|
||||||
displayname=data.get("displayname", ""),
|
displayname=data.get("displayname", "disable"),
|
||||||
avatar_url=data.get("avatar_url", ""),
|
avatar_url=data.get("avatar_url", "disable"),
|
||||||
device_id=device_id)
|
device_id=device_id)
|
||||||
client = Client(db_instance)
|
client = Client(db_instance)
|
||||||
client.db_instance.insert()
|
client.db_instance.insert()
|
||||||
|
|
|
@ -17,12 +17,15 @@ from typing import Dict, Tuple, NamedTuple, Optional
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import asyncio
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
import hmac
|
import hmac
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from mautrix.api import SynapseAdminPath, Method
|
from yarl import URL
|
||||||
|
|
||||||
|
from mautrix.api import SynapseAdminPath, Method, Path
|
||||||
from mautrix.errors import MatrixRequestError
|
from mautrix.errors import MatrixRequestError
|
||||||
from mautrix.client import ClientAPI
|
from mautrix.client import ClientAPI
|
||||||
from mautrix.types import LoginType, LoginResponse
|
from mautrix.types import LoginType, LoginResponse
|
||||||
|
@ -42,6 +45,7 @@ async def get_known_servers(_: web.Request) -> web.Response:
|
||||||
|
|
||||||
|
|
||||||
class AuthRequestInfo(NamedTuple):
|
class AuthRequestInfo(NamedTuple):
|
||||||
|
server_name: str
|
||||||
client: ClientAPI
|
client: ClientAPI
|
||||||
secret: str
|
secret: str
|
||||||
username: str
|
username: str
|
||||||
|
@ -49,6 +53,10 @@ class AuthRequestInfo(NamedTuple):
|
||||||
user_type: str
|
user_type: str
|
||||||
device_name: str
|
device_name: str
|
||||||
update_client: bool
|
update_client: bool
|
||||||
|
sso: bool
|
||||||
|
|
||||||
|
|
||||||
|
truthy_strings = ("1", "true", "yes")
|
||||||
|
|
||||||
|
|
||||||
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
|
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
|
||||||
|
@ -61,23 +69,28 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
return None, resp.body_not_json
|
return None, resp.body_not_json
|
||||||
|
sso = request.query.get("sso", "").lower() in truthy_strings
|
||||||
try:
|
try:
|
||||||
username = body["username"]
|
username = body["username"]
|
||||||
password = body["password"]
|
password = body["password"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
if not sso:
|
||||||
return None, resp.username_or_password_missing
|
return None, resp.username_or_password_missing
|
||||||
|
username = password = None
|
||||||
try:
|
try:
|
||||||
base_url = server["url"]
|
base_url = server["url"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None, resp.invalid_server
|
return None, resp.invalid_server
|
||||||
return AuthRequestInfo(
|
return AuthRequestInfo(
|
||||||
|
server_name=server_name,
|
||||||
client=ClientAPI(base_url=base_url, loop=get_loop()),
|
client=ClientAPI(base_url=base_url, loop=get_loop()),
|
||||||
secret=server.get("secret"),
|
secret=server.get("secret"),
|
||||||
username=username,
|
username=username,
|
||||||
password=password,
|
password=password,
|
||||||
user_type=body.get("user_type", "bot"),
|
user_type=body.get("user_type", "bot"),
|
||||||
device_name=body.get("device_name", "Maubot"),
|
device_name=body.get("device_name", "Maubot"),
|
||||||
update_client=request.query.get("update_client", "").lower() in ("1", "true", "yes"),
|
update_client=request.query.get("update_client", "").lower() in truthy_strings,
|
||||||
|
sso=sso,
|
||||||
), None
|
), None
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,7 +115,9 @@ async def register(request: web.Request) -> web.Response:
|
||||||
req, err = await read_client_auth_request(request)
|
req, err = await read_client_auth_request(request)
|
||||||
if err is not None:
|
if err is not None:
|
||||||
return err
|
return err
|
||||||
if not req.secret:
|
if req.sso:
|
||||||
|
return resp.registration_no_sso
|
||||||
|
elif not req.secret:
|
||||||
return resp.registration_secret_not_found
|
return resp.registration_secret_not_found
|
||||||
path = SynapseAdminPath.v1.register
|
path = SynapseAdminPath.v1.register
|
||||||
res = await req.client.api.request(Method.GET, path)
|
res = await req.client.api.request(Method.GET, path)
|
||||||
|
@ -137,10 +152,38 @@ async def login(request: web.Request) -> web.Response:
|
||||||
req, err = await read_client_auth_request(request)
|
req, err = await read_client_auth_request(request)
|
||||||
if err is not None:
|
if err is not None:
|
||||||
return err
|
return err
|
||||||
device_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
|
if req.sso:
|
||||||
|
return await _do_sso(req)
|
||||||
|
else:
|
||||||
|
return await _do_login(req)
|
||||||
|
|
||||||
|
|
||||||
|
async def _do_sso(req: AuthRequestInfo) -> web.Response:
|
||||||
|
flows = await req.client.get_login_flows()
|
||||||
|
if not flows.supports_type(LoginType.SSO):
|
||||||
|
return resp.sso_not_supported
|
||||||
|
waiter_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=16))
|
||||||
|
cfg = get_config()
|
||||||
|
public_url = (URL(cfg["server.public_url"]) / cfg["server.base_path"].lstrip("/")
|
||||||
|
/ "client/auth_external_sso/complete" / waiter_id)
|
||||||
|
sso_url = (req.client.api.base_url
|
||||||
|
.with_path(str(Path.login.sso.redirect))
|
||||||
|
.with_query({"redirectUrl": str(public_url)}))
|
||||||
|
sso_waiters[waiter_id] = req, get_loop().create_future()
|
||||||
|
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
|
||||||
|
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
|
||||||
|
device_id = f"maubot_{device_id}"
|
||||||
try:
|
try:
|
||||||
|
if req.sso:
|
||||||
|
res = await req.client.login(token=login_token, login_type=LoginType.TOKEN,
|
||||||
|
device_id=device_id, store_access_token=False,
|
||||||
|
initial_device_display_name=req.device_name)
|
||||||
|
else:
|
||||||
res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD,
|
res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD,
|
||||||
password=req.password, device_id=f"maubot_{device_id}",
|
password=req.password, device_id=device_id,
|
||||||
initial_device_display_name=req.device_name,
|
initial_device_display_name=req.device_name,
|
||||||
store_access_token=False)
|
store_access_token=False)
|
||||||
except MatrixRequestError as e:
|
except MatrixRequestError as e:
|
||||||
|
@ -155,3 +198,38 @@ async def login(request: web.Request) -> web.Response:
|
||||||
"device_id": res.device_id,
|
"device_id": res.device_id,
|
||||||
}, is_login=True)
|
}, is_login=True)
|
||||||
return web.json_response(res.serialize())
|
return web.json_response(res.serialize())
|
||||||
|
|
||||||
|
|
||||||
|
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@routes.post("/client/auth/{server}/sso/{id}/wait")
|
||||||
|
async def wait_sso(request: web.Request) -> web.Response:
|
||||||
|
waiter_id = request.match_info["id"]
|
||||||
|
req, fut = sso_waiters[waiter_id]
|
||||||
|
try:
|
||||||
|
login_token = await fut
|
||||||
|
finally:
|
||||||
|
sso_waiters.pop(waiter_id, None)
|
||||||
|
return await _do_login(req, login_token)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.get("/client/auth_external_sso/complete/{id}")
|
||||||
|
async def complete_sso(request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
_, fut = sso_waiters[request.match_info["id"]]
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=404, text="Invalid session ID\n")
|
||||||
|
if fut.cancelled():
|
||||||
|
return web.Response(status=200, text="The login was cancelled from the Maubot client\n")
|
||||||
|
elif fut.done():
|
||||||
|
return web.Response(status=200, text="The login token was already received\n")
|
||||||
|
try:
|
||||||
|
fut.set_result(request.query["loginToken"])
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=400, text="Missing loginToken query parameter\n")
|
||||||
|
except asyncio.InvalidStateError:
|
||||||
|
return web.Response(status=500, text="Invalid state\n")
|
||||||
|
return web.Response(status=200,
|
||||||
|
text="Login token received, please return to your Maubot client. "
|
||||||
|
"This tab can be closed.\n")
|
||||||
|
|
|
@ -29,7 +29,12 @@ log = logging.getLogger("maubot.server")
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def auth(request: web.Request, handler: Handler) -> web.Response:
|
async def auth(request: web.Request, handler: Handler) -> web.Response:
|
||||||
subpath = request.path[len(get_config()["server.base_path"]):]
|
subpath = request.path[len(get_config()["server.base_path"]):]
|
||||||
if subpath.startswith("/auth/") or subpath == "/features" or subpath == "/logs":
|
if (
|
||||||
|
subpath.startswith("/auth/")
|
||||||
|
or subpath.startswith("/client/auth_external_sso/complete/")
|
||||||
|
or subpath == "/features"
|
||||||
|
or subpath == "/logs"
|
||||||
|
):
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
err = check_token(request)
|
err = check_token(request)
|
||||||
if err is not None:
|
if err is not None:
|
||||||
|
|
|
@ -194,6 +194,20 @@ class _Response:
|
||||||
"errcode": "registration_secret_not_found",
|
"errcode": "registration_secret_not_found",
|
||||||
}, status=HTTPStatus.NOT_FOUND)
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def registration_no_sso(self) -> web.Response:
|
||||||
|
return web.json_response({
|
||||||
|
"error": "The register operation is only for registering with a password",
|
||||||
|
"errcode": "registration_no_sso",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sso_not_supported(self) -> web.Response:
|
||||||
|
return web.json_response({
|
||||||
|
"error": "That server does not seem to support single sign-on",
|
||||||
|
"errcode": "sso_not_supported",
|
||||||
|
}, status=HTTPStatus.FORBIDDEN)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def plugin_has_no_database(self) -> web.Response:
|
def plugin_has_no_database(self) -> web.Response:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
|
|
Loading…
Reference in a new issue