Add autojoin event handler

This commit is contained in:
Tulir Asokan 2018-10-13 00:30:05 +03:00
parent 6969b2e273
commit 91dc5646c7

View file

@ -17,8 +17,9 @@ from typing import Dict, Optional
from aiohttp import ClientSession
import logging
from mautrix import ClientAPI
from mautrix.types import UserID, SyncToken, FilterID, ContentURI
from mautrix import Client as MatrixClient
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
EventType)
from .db import DBClient
@ -32,11 +33,13 @@ class Client:
def __init__(self, db_instance: DBClient) -> None:
self.db_instance: DBClient = db_instance
self.cache[self.id] = self
self.client: ClientAPI = ClientAPI(mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=log.getChild(self.id))
self.client: MatrixClient = MatrixClient(mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=log.getChild(self.id))
if self.autojoin:
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
@classmethod
def get(cls, id: UserID) -> Optional['Client']:
@ -49,6 +52,7 @@ class Client:
return Client(db_instance)
# region Properties
@property
def id(self) -> UserID:
return self.db_instance.id
@ -63,6 +67,7 @@ class Client:
@access_token.setter
def access_token(self, value: str) -> None:
self.client.api.token = value
self.db_instance.access_token = value
@property
@ -95,6 +100,12 @@ class Client:
@autojoin.setter
def autojoin(self, value: bool) -> None:
if value == self.db_instance.autojoin:
return
if value:
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
else:
self.client.remove_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
self.db_instance.autojoin = value
@property
@ -114,3 +125,7 @@ class Client:
self.db_instance.avatar_url = value
# endregion
async def handle_invite(self, evt: StateEvent):
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)