Add support for end-to-end encryption. Fixes #46

This commit is contained in:
Tulir Asokan 2020-07-12 14:55:41 +03:00
parent 4e767a10e4
commit 69d7a4341b
17 changed files with 203 additions and 24 deletions

4
MANIFEST.in Normal file
View file

@ -0,0 +1,4 @@
include README.md
include LICENSE
include requirements.txt
include optional-requirements.txt

View file

@ -0,0 +1,47 @@
"""Add Matrix state store
Revision ID: 90aa88820eab
Revises: 4b93300852aa
Create Date: 2020-07-12 01:50:06.215623
"""
from alembic import op
import sqlalchemy as sa
from mautrix.client.state_store.sqlalchemy import SerializableType
from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
# revision identifiers, used by Alembic.
revision = '90aa88820eab'
down_revision = '4b93300852aa'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('mx_room_state',
sa.Column('room_id', sa.String(length=255), nullable=False),
sa.Column('is_encrypted', sa.Boolean(), nullable=True),
sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
sa.PrimaryKeyConstraint('room_id')
)
op.create_table('mx_user_profile',
sa.Column('room_id', sa.String(length=255), nullable=False),
sa.Column('user_id', sa.String(length=255), nullable=False),
sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
sa.Column('displayname', sa.String(), nullable=True),
sa.Column('avatar_url', sa.String(length=255), nullable=True),
sa.PrimaryKeyConstraint('room_id', 'user_id')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('mx_user_profile')
op.drop_table('mx_room_state')
# ### end Alembic commands ###

View file

@ -5,6 +5,18 @@
# Postgres: postgres://username:password@hostname/dbname # Postgres: postgres://username:password@hostname/dbname
database: sqlite:///maubot.db database: sqlite:///maubot.db
# Database for encryption data.
crypto_database:
# Type of database. Either "default", "pickle" or "postgres".
# When set to default, using SQLite as the main database will use pickle as the crypto database
# and using Postgres as the main database will use the same one as the crypto database.
#
# When using pickle, individual crypto databases are stored in the pickle_dir directory.
# When using non-default postgres, postgres_uri is used to connect to postgres.
type: default
postgres_uri: postgres://username:password@hostname/dbname
pickle_dir: ./crypto
plugin_directories: plugin_directories:
# The directory where uploaded new plugins should be stored. # The directory where uploaded new plugins should be stored.
upload: ./plugins upload: ./plugins

View file

@ -57,7 +57,7 @@ log.info(f"Initializing maubot {__version__}")
init_zip_loader(config) init_zip_loader(config)
db_engine = init_db(config) db_engine = init_db(config)
clients = init_client_class(loop) clients = init_client_class(config, loop)
management_api = init_mgmt_api(config, loop) management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop) server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(config, server, loop) plugins = init_plugin_instance_class(config, server, loop)
@ -72,6 +72,9 @@ signal.signal(signal.SIGTERM, signal.default_int_handler)
try: try:
log.info("Starting server") log.info("Starting server")
loop.run_until_complete(server.start()) loop.run_until_complete(server.start())
if Client.crypto_db:
log.debug("Starting client crypto database")
loop.run_until_complete(Client.crypto_db.start())
log.info("Starting clients and plugins") log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients])) loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
log.info("Startup actions complete, running forever") log.info("Startup actions complete, running forever")

View file

@ -18,12 +18,13 @@ from io import BytesIO
import zipfile import zipfile
import os import os
from mautrix.client.api.types.util import SerializerError
from ruamel.yaml import YAML, YAMLError from ruamel.yaml import YAML, YAMLError
from colorama import Fore from colorama import Fore
from PyInquirer import prompt from PyInquirer import prompt
import click import click
from mautrix.types import SerializerError
from ...loader import PluginMeta from ...loader import PluginMeta
from ..cliq.validators import PathValidator from ..cliq.validators import PathValidator
from ..base import app from ..base import app

View file

@ -18,9 +18,10 @@ import asyncio
from colorama import Fore from colorama import Fore
from aiohttp import WSMsgType, WSMessage, ClientSession from aiohttp import WSMsgType, WSMessage, ClientSession
from mautrix.client.api.types.util import Obj
import click import click
from mautrix.types import Obj
from ..config import get_token from ..config import get_token
from ..base import app from ..base import app

View file

@ -13,24 +13,46 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
from os import path
import asyncio import asyncio
import logging import logging
from aiohttp import ClientSession from aiohttp import ClientSession
from yarl import URL
from mautrix.errors import MatrixInvalidToken, MatrixRequestError from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership, from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter, StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter) PresenceState, StateFilter)
from mautrix.client import InternalEventType from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
from .lib.store_proxy import ClientStoreProxy from .lib.store_proxy import SyncStoreProxy
from .db import DBClient from .db import DBClient
from .matrix import MaubotMatrixClient from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
PickleCryptoStore)
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
except ImportError:
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
SQLStateStore = BaseSQLStateStore
try:
from mautrix.util.async_db import Database as AsyncDatabase
from mautrix.crypto import PgCryptoStore
except ImportError:
AsyncDatabase = None
PgCryptoStore = None
if TYPE_CHECKING: if TYPE_CHECKING:
from .instance import PluginInstance from .instance import PluginInstance
from .config import Config
log = logging.getLogger("maubot.client") log = logging.getLogger("maubot.client")
@ -40,10 +62,15 @@ class Client:
loop: asyncio.AbstractEventLoop = None loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {} cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
crypto_pickle_dir: str = None
crypto_db: 'AsyncDatabase' = None
references: Set['PluginInstance'] references: Set['PluginInstance']
db_instance: DBClient db_instance: DBClient
client: MaubotMatrixClient client: MaubotMatrixClient
crypto: Optional['OlmMachine']
crypto_store: Optional['CryptoStore']
started: bool started: bool
remote_displayname: Optional[str] remote_displayname: Optional[str]
@ -61,7 +88,15 @@ class Client:
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver, self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client, token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, device_id=self.device_id, log=self.log, loop=self.loop, device_id=self.device_id,
store=ClientStoreProxy(self.db_instance)) sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store)
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
self.crypto_store = self._make_crypto_store()
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
else:
self.crypto_store = None
self.crypto = None
self.client.ignore_initial_sync = True self.client.ignore_initial_sync = True
self.client.ignore_first_sync = True self.client.ignore_first_sync = True
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
@ -71,6 +106,14 @@ class Client:
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False)) self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True)) self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
def _make_crypto_store(self) -> 'CryptoStore':
if self.crypto_db:
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
elif self.crypto_pickle_dir:
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
raise ValueError("Crypto database not configured")
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]: def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
async def handler(data: Dict[str, Any]) -> None: async def handler(data: Dict[str, Any]) -> None:
self.sync_ok = ok self.sync_ok = ok
@ -130,6 +173,16 @@ class Client:
await self.client.set_displayname(self.displayname) await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable": if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url) await self.client.set_avatar_url(self.avatar_url)
if self.crypto:
self.log.debug("Enabling end-to-end encryption support")
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning("Mismatching device ID in crypto store and main database. "
"Encryption may not work.")
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
self.start_sync() self.start_sync()
await self._update_remote_profile() await self._update_remote_profile()
self.started = True self.started = True
@ -154,6 +207,8 @@ class Client:
self.started = False self.started = False
await self.stop_plugins() await self.stop_plugins()
self.stop_sync() self.stop_sync()
if self.crypto:
await self.crypto_store.close()
def clear_cache(self) -> None: def clear_cache(self) -> None:
self.stop_sync() self.stop_sync()
@ -172,6 +227,7 @@ class Client:
"id": self.id, "id": self.id,
"homeserver": self.homeserver, "homeserver": self.homeserver,
"access_token": self.access_token, "access_token": self.access_token,
"device_id": self.device_id,
"enabled": self.enabled, "enabled": self.enabled,
"started": self.started, "started": self.started,
"sync": self.sync, "sync": self.sync,
@ -243,11 +299,12 @@ class Client:
return return
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver, new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop, token=access_token or self.access_token, loop=self.loop,
client_session=self.http_client, log=self.log) client_session=self.http_client, device_id=self.device_id,
log=self.log, state_store=self.global_state_store)
mxid = await new_client.whoami() mxid = await new_client.whoami()
if mxid != self.id: if mxid != self.id:
raise ValueError(f"MXID mismatch: {mxid}") raise ValueError(f"MXID mismatch: {mxid}")
new_client.store = self.db_instance new_client.sync_store = self.db_instance
self.stop_sync() self.stop_sync()
self.client = new_client self.client = new_client
self.db_instance.homeserver = homeserver self.db_instance.homeserver = homeserver
@ -341,7 +398,30 @@ class Client:
# endregion # endregion
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]: def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop) Client.http_client = ClientSession(loop=loop)
Client.loop = loop Client.loop = loop
if OlmMachine:
db_type = config["crypto_database.type"]
if db_type == "default":
db_url = config["database"]
parsed_url = URL(db_url)
if parsed_url.scheme == "sqlite":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif parsed_url.scheme == "postgres":
if not PgCryptoStore:
log.warning("Default database is postgres, but asyncpg is not installed. "
"Encryption will not work.")
else:
Client.crypto_db = AsyncDatabase(url=db_url,
upgrade_table=PgCryptoStore.upgrade_table)
elif db_type == "pickle":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif db_type == "postgres" and PgCryptoStore:
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
upgrade_table=PgCryptoStore.upgrade_table)
else:
raise ValueError("Unsupported crypto database type")
return Client.all() return Client.all()

View file

@ -32,6 +32,9 @@ class Config(BaseFileConfig):
base = helper.base base = helper.base
copy = helper.copy copy = helper.copy
copy("database") copy("database")
copy("crypto_database.type")
copy("crypto_database.postgres_uri")
copy("crypto_database.pickle_dir")
copy("plugin_directories.upload") copy("plugin_directories.upload")
copy("plugin_directories.load") copy("plugin_directories.load")
copy("plugin_directories.trash") copy("plugin_directories.trash")

View file

@ -23,6 +23,7 @@ import sqlalchemy as sql
from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI
from mautrix.util.db import Base from mautrix.util.db import Base
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
from .config import Config from .config import Config
@ -79,7 +80,7 @@ def init(config: Config) -> Engine:
db = sql.create_engine(config["database"]) db = sql.create_engine(config["database"])
Base.metadata.bind = db Base.metadata.bind = db
for table in (DBPlugin, DBClient): for table in (DBPlugin, DBClient, RoomState, UserProfile):
table.bind(db) table.bind(db)
if not db.has_table("alembic_version"): if not db.has_table("alembic_version"):

View file

@ -13,11 +13,11 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.client import ClientStore from mautrix.client import SyncStore
from mautrix.types import SyncToken from mautrix.types import SyncToken
class ClientStoreProxy(ClientStore): class SyncStoreProxy(SyncStore):
def __init__(self, db_instance) -> None: def __init__(self, db_instance) -> None:
self.db_instance = db_instance self.db_instance = db_instance

View file

@ -19,8 +19,8 @@ import asyncio
from attr import dataclass from attr import dataclass
from packaging.version import Version, InvalidVersion from packaging.version import Version, InvalidVersion
from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
deserializer) from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer
from ..__meta__ import __version__ from ..__meta__ import __version__
from ..plugin_base import Plugin from ..plugin_base import Plugin

View file

@ -22,7 +22,8 @@ import os
from ruamel.yaml import YAML, YAMLError from ruamel.yaml import YAML, YAMLError
from packaging.version import Version from packaging.version import Version
from mautrix.client.api.types.util import SerializerError
from mautrix.types import SerializerError
from ..lib.zipimport import zipimporter, ZipImportError from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin from ..plugin_base import Plugin

View file

@ -13,13 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, Awaitable, Optional, Tuple from typing import Union, Awaitable, Optional, Tuple, List
from html import escape from html import escape
import asyncio
import attr import attr
from mautrix.client import Client as MatrixClient, SyncStream from mautrix.client import Client as MatrixClient, SyncStream
from mautrix.util.formatter import parse_html from mautrix.util import markdown, formatter
from mautrix.util import markdown
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent, from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
MessageType, TextMessageEventContent, Format, RelatesTo) MessageType, TextMessageEventContent, Format, RelatesTo)
@ -32,7 +33,7 @@ def parse_formatted(message: str, allow_html: bool = False, render_markdown: boo
html = message html = message
else: else:
return message, escape(message) return message, escape(message)
return parse_html(html), html return formatter.parse_html(html), html
class MaubotMessageEvent(MessageEvent): class MaubotMessageEvent(MessageEvent):
@ -110,12 +111,12 @@ class MaubotMatrixClient(MatrixClient):
content.set_edit(edits) content.set_edit(edits)
return self.send_message(room_id, content, **kwargs) return self.send_message(room_id, content, **kwargs)
async def dispatch_event(self, event: Event, source: SyncStream = SyncStream.INTERNAL) -> None: def dispatch_event(self, event: Event, source: SyncStream) -> List[asyncio.Task]:
if isinstance(event, MessageEvent): if isinstance(event, MessageEvent):
event = MaubotMessageEvent(event, self) event = MaubotMessageEvent(event, self)
elif source != SyncStream.INTERNAL: elif source != SyncStream.INTERNAL:
event.client = self event.client = self
return await super().dispatch_event(event, source) return super().dispatch_event(event, source)
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event: async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
event = await super().get_event(room_id, event_id) event = await super().get_event(room_id, event_id)

View file

@ -36,7 +36,7 @@ from .config import Config
from ..plugin_base import Plugin from ..plugin_base import Plugin
from ..loader import PluginMeta from ..loader import PluginMeta
from ..matrix import MaubotMatrixClient from ..matrix import MaubotMatrixClient
from ..lib.store_proxy import ClientStoreProxy from ..lib.store_proxy import SyncStoreProxy
from ..__meta__ import __version__ from ..__meta__ import __version__
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -143,7 +143,7 @@ async def main():
global client, bot global client, bot
client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token, client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token,
client_session=http_client, loop=loop, store=ClientStoreProxy(nb), client_session=http_client, loop=loop, store=SyncStoreProxy(nb),
log=logging.getLogger("maubot.client").getChild(user_id)) log=logging.getLogger("maubot.client").getChild(user_id))
while True: while True:

11
optional-requirements.txt Normal file
View file

@ -0,0 +1,11 @@
# Format: #/name defines a new extras_require group called name
# Uncommented lines after the group definition insert things into that group.
#/postgres
psycopg2-binary>=2,<3
#/e2be
asyncpg>=0.20,<0.21
python-olm>=3,<4
pycryptodome>=3,<4
unpaddedbase64>=1,<2

View file

@ -1,4 +1,4 @@
mautrix==0.6.0.beta7 mautrix==0.6.0rc1
aiohttp>=3,<4 aiohttp>=3,<4
SQLAlchemy>=1,<2 SQLAlchemy>=1,<2
alembic>=1,<2 alembic>=1,<2

View file

@ -5,6 +5,19 @@ import os
with open("requirements.txt") as reqs: with open("requirements.txt") as reqs:
install_requires = reqs.read().splitlines() install_requires = reqs.read().splitlines()
with open("optional-requirements.txt") as reqs:
extras_require = {}
current = []
for line in reqs.read().splitlines():
if line.startswith("#/"):
extras_require[line[2:]] = current = []
elif not line or line.startswith("#"):
continue
else:
current.append(line)
extras_require["all"] = list({dep for deps in extras_require.values() for dep in deps})
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "maubot", "__meta__.py") path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "maubot", "__meta__.py")
__version__ = "UNKNOWN" __version__ = "UNKNOWN"
with open(path) as f: with open(path) as f:
@ -25,6 +38,7 @@ setuptools.setup(
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
install_requires=install_requires, install_requires=install_requires,
extras_require=extras_require,
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",