Add support for end-to-end encryption. Fixes #46
This commit is contained in:
parent
4e767a10e4
commit
69d7a4341b
17 changed files with 203 additions and 24 deletions
4
MANIFEST.in
Normal file
4
MANIFEST.in
Normal file
|
@ -0,0 +1,4 @@
|
|||
include README.md
|
||||
include LICENSE
|
||||
include requirements.txt
|
||||
include optional-requirements.txt
|
47
alembic/versions/90aa88820eab_add_matrix_state_store.py
Normal file
47
alembic/versions/90aa88820eab_add_matrix_state_store.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
"""Add Matrix state store
|
||||
|
||||
Revision ID: 90aa88820eab
|
||||
Revises: 4b93300852aa
|
||||
Create Date: 2020-07-12 01:50:06.215623
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from mautrix.client.state_store.sqlalchemy import SerializableType
|
||||
from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '90aa88820eab'
|
||||
down_revision = '4b93300852aa'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('mx_room_state',
|
||||
sa.Column('room_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('is_encrypted', sa.Boolean(), nullable=True),
|
||||
sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
|
||||
sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
|
||||
sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
|
||||
sa.PrimaryKeyConstraint('room_id')
|
||||
)
|
||||
op.create_table('mx_user_profile',
|
||||
sa.Column('room_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
|
||||
sa.Column('displayname', sa.String(), nullable=True),
|
||||
sa.Column('avatar_url', sa.String(length=255), nullable=True),
|
||||
sa.PrimaryKeyConstraint('room_id', 'user_id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('mx_user_profile')
|
||||
op.drop_table('mx_room_state')
|
||||
# ### end Alembic commands ###
|
|
@ -5,6 +5,18 @@
|
|||
# Postgres: postgres://username:password@hostname/dbname
|
||||
database: sqlite:///maubot.db
|
||||
|
||||
# Database for encryption data.
|
||||
crypto_database:
|
||||
# Type of database. Either "default", "pickle" or "postgres".
|
||||
# When set to default, using SQLite as the main database will use pickle as the crypto database
|
||||
# and using Postgres as the main database will use the same one as the crypto database.
|
||||
#
|
||||
# When using pickle, individual crypto databases are stored in the pickle_dir directory.
|
||||
# When using non-default postgres, postgres_uri is used to connect to postgres.
|
||||
type: default
|
||||
postgres_uri: postgres://username:password@hostname/dbname
|
||||
pickle_dir: ./crypto
|
||||
|
||||
plugin_directories:
|
||||
# The directory where uploaded new plugins should be stored.
|
||||
upload: ./plugins
|
||||
|
|
|
@ -57,7 +57,7 @@ log.info(f"Initializing maubot {__version__}")
|
|||
|
||||
init_zip_loader(config)
|
||||
db_engine = init_db(config)
|
||||
clients = init_client_class(loop)
|
||||
clients = init_client_class(config, loop)
|
||||
management_api = init_mgmt_api(config, loop)
|
||||
server = MaubotServer(management_api, config, loop)
|
||||
plugins = init_plugin_instance_class(config, server, loop)
|
||||
|
@ -72,6 +72,9 @@ signal.signal(signal.SIGTERM, signal.default_int_handler)
|
|||
try:
|
||||
log.info("Starting server")
|
||||
loop.run_until_complete(server.start())
|
||||
if Client.crypto_db:
|
||||
log.debug("Starting client crypto database")
|
||||
loop.run_until_complete(Client.crypto_db.start())
|
||||
log.info("Starting clients and plugins")
|
||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
||||
log.info("Startup actions complete, running forever")
|
||||
|
|
|
@ -18,12 +18,13 @@ from io import BytesIO
|
|||
import zipfile
|
||||
import os
|
||||
|
||||
from mautrix.client.api.types.util import SerializerError
|
||||
from ruamel.yaml import YAML, YAMLError
|
||||
from colorama import Fore
|
||||
from PyInquirer import prompt
|
||||
import click
|
||||
|
||||
from mautrix.types import SerializerError
|
||||
|
||||
from ...loader import PluginMeta
|
||||
from ..cliq.validators import PathValidator
|
||||
from ..base import app
|
||||
|
|
|
@ -18,9 +18,10 @@ import asyncio
|
|||
|
||||
from colorama import Fore
|
||||
from aiohttp import WSMsgType, WSMessage, ClientSession
|
||||
from mautrix.client.api.types.util import Obj
|
||||
import click
|
||||
|
||||
from mautrix.types import Obj
|
||||
|
||||
from ..config import get_token
|
||||
from ..base import app
|
||||
|
||||
|
|
|
@ -13,24 +13,46 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
|
||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
|
||||
from os import path
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from yarl import URL
|
||||
|
||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
|
||||
PresenceState, StateFilter)
|
||||
from mautrix.client import InternalEventType
|
||||
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
||||
|
||||
from .lib.store_proxy import ClientStoreProxy
|
||||
from .lib.store_proxy import SyncStoreProxy
|
||||
from .db import DBClient
|
||||
from .matrix import MaubotMatrixClient
|
||||
|
||||
try:
|
||||
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
|
||||
PickleCryptoStore)
|
||||
|
||||
|
||||
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
||||
pass
|
||||
except ImportError:
|
||||
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
|
||||
SQLStateStore = BaseSQLStateStore
|
||||
|
||||
try:
|
||||
from mautrix.util.async_db import Database as AsyncDatabase
|
||||
from mautrix.crypto import PgCryptoStore
|
||||
except ImportError:
|
||||
AsyncDatabase = None
|
||||
PgCryptoStore = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .instance import PluginInstance
|
||||
from .config import Config
|
||||
|
||||
log = logging.getLogger("maubot.client")
|
||||
|
||||
|
@ -40,10 +62,15 @@ class Client:
|
|||
loop: asyncio.AbstractEventLoop = None
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
http_client: ClientSession = None
|
||||
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
|
||||
crypto_pickle_dir: str = None
|
||||
crypto_db: 'AsyncDatabase' = None
|
||||
|
||||
references: Set['PluginInstance']
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
crypto: Optional['OlmMachine']
|
||||
crypto_store: Optional['CryptoStore']
|
||||
started: bool
|
||||
|
||||
remote_displayname: Optional[str]
|
||||
|
@ -61,7 +88,15 @@ class Client:
|
|||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||
token=self.access_token, client_session=self.http_client,
|
||||
log=self.log, loop=self.loop, device_id=self.device_id,
|
||||
store=ClientStoreProxy(self.db_instance))
|
||||
sync_store=SyncStoreProxy(self.db_instance),
|
||||
state_store=self.global_state_store)
|
||||
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
|
||||
self.crypto_store = self._make_crypto_store()
|
||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||
self.client.crypto = self.crypto
|
||||
else:
|
||||
self.crypto_store = None
|
||||
self.crypto = None
|
||||
self.client.ignore_initial_sync = True
|
||||
self.client.ignore_first_sync = True
|
||||
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
|
||||
|
@ -71,6 +106,14 @@ class Client:
|
|||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||
|
||||
def _make_crypto_store(self) -> 'CryptoStore':
|
||||
if self.crypto_db:
|
||||
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
|
||||
elif self.crypto_pickle_dir:
|
||||
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
|
||||
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
|
||||
raise ValueError("Crypto database not configured")
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: Dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
@ -130,6 +173,16 @@ class Client:
|
|||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
if self.crypto:
|
||||
self.log.debug("Enabling end-to-end encryption support")
|
||||
await self.crypto_store.open()
|
||||
crypto_device_id = await self.crypto_store.get_device_id()
|
||||
if crypto_device_id and crypto_device_id != self.device_id:
|
||||
self.log.warning("Mismatching device ID in crypto store and main database. "
|
||||
"Encryption may not work.")
|
||||
await self.crypto.load()
|
||||
if not crypto_device_id:
|
||||
await self.crypto_store.put_device_id(self.device_id)
|
||||
self.start_sync()
|
||||
await self._update_remote_profile()
|
||||
self.started = True
|
||||
|
@ -154,6 +207,8 @@ class Client:
|
|||
self.started = False
|
||||
await self.stop_plugins()
|
||||
self.stop_sync()
|
||||
if self.crypto:
|
||||
await self.crypto_store.close()
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
self.stop_sync()
|
||||
|
@ -172,6 +227,7 @@ class Client:
|
|||
"id": self.id,
|
||||
"homeserver": self.homeserver,
|
||||
"access_token": self.access_token,
|
||||
"device_id": self.device_id,
|
||||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"sync": self.sync,
|
||||
|
@ -243,11 +299,12 @@ class Client:
|
|||
return
|
||||
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token, loop=self.loop,
|
||||
client_session=self.http_client, log=self.log)
|
||||
client_session=self.http_client, device_id=self.device_id,
|
||||
log=self.log, state_store=self.global_state_store)
|
||||
mxid = await new_client.whoami()
|
||||
if mxid != self.id:
|
||||
raise ValueError(f"MXID mismatch: {mxid}")
|
||||
new_client.store = self.db_instance
|
||||
new_client.sync_store = self.db_instance
|
||||
self.stop_sync()
|
||||
self.client = new_client
|
||||
self.db_instance.homeserver = homeserver
|
||||
|
@ -341,7 +398,30 @@ class Client:
|
|||
# endregion
|
||||
|
||||
|
||||
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
|
||||
if OlmMachine:
|
||||
db_type = config["crypto_database.type"]
|
||||
if db_type == "default":
|
||||
db_url = config["database"]
|
||||
parsed_url = URL(db_url)
|
||||
if parsed_url.scheme == "sqlite":
|
||||
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
||||
elif parsed_url.scheme == "postgres":
|
||||
if not PgCryptoStore:
|
||||
log.warning("Default database is postgres, but asyncpg is not installed. "
|
||||
"Encryption will not work.")
|
||||
else:
|
||||
Client.crypto_db = AsyncDatabase(url=db_url,
|
||||
upgrade_table=PgCryptoStore.upgrade_table)
|
||||
elif db_type == "pickle":
|
||||
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
||||
elif db_type == "postgres" and PgCryptoStore:
|
||||
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
|
||||
upgrade_table=PgCryptoStore.upgrade_table)
|
||||
else:
|
||||
raise ValueError("Unsupported crypto database type")
|
||||
|
||||
return Client.all()
|
||||
|
|
|
@ -32,6 +32,9 @@ class Config(BaseFileConfig):
|
|||
base = helper.base
|
||||
copy = helper.copy
|
||||
copy("database")
|
||||
copy("crypto_database.type")
|
||||
copy("crypto_database.postgres_uri")
|
||||
copy("crypto_database.pickle_dir")
|
||||
copy("plugin_directories.upload")
|
||||
copy("plugin_directories.load")
|
||||
copy("plugin_directories.trash")
|
||||
|
|
|
@ -23,6 +23,7 @@ import sqlalchemy as sql
|
|||
|
||||
from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI
|
||||
from mautrix.util.db import Base
|
||||
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
|
||||
|
||||
from .config import Config
|
||||
|
||||
|
@ -79,7 +80,7 @@ def init(config: Config) -> Engine:
|
|||
db = sql.create_engine(config["database"])
|
||||
Base.metadata.bind = db
|
||||
|
||||
for table in (DBPlugin, DBClient):
|
||||
for table in (DBPlugin, DBClient, RoomState, UserProfile):
|
||||
table.bind(db)
|
||||
|
||||
if not db.has_table("alembic_version"):
|
||||
|
|
|
@ -13,11 +13,11 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from mautrix.client import ClientStore
|
||||
from mautrix.client import SyncStore
|
||||
from mautrix.types import SyncToken
|
||||
|
||||
|
||||
class ClientStoreProxy(ClientStore):
|
||||
class SyncStoreProxy(SyncStore):
|
||||
def __init__(self, db_instance) -> None:
|
||||
self.db_instance = db_instance
|
||||
|
||||
|
|
|
@ -19,8 +19,8 @@ import asyncio
|
|||
|
||||
from attr import dataclass
|
||||
from packaging.version import Version, InvalidVersion
|
||||
from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
|
||||
deserializer)
|
||||
|
||||
from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer
|
||||
|
||||
from ..__meta__ import __version__
|
||||
from ..plugin_base import Plugin
|
||||
|
|
|
@ -22,7 +22,8 @@ import os
|
|||
|
||||
from ruamel.yaml import YAML, YAMLError
|
||||
from packaging.version import Version
|
||||
from mautrix.client.api.types.util import SerializerError
|
||||
|
||||
from mautrix.types import SerializerError
|
||||
|
||||
from ..lib.zipimport import zipimporter, ZipImportError
|
||||
from ..plugin_base import Plugin
|
||||
|
|
|
@ -13,13 +13,14 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Union, Awaitable, Optional, Tuple
|
||||
from typing import Union, Awaitable, Optional, Tuple, List
|
||||
from html import escape
|
||||
import asyncio
|
||||
|
||||
import attr
|
||||
|
||||
from mautrix.client import Client as MatrixClient, SyncStream
|
||||
from mautrix.util.formatter import parse_html
|
||||
from mautrix.util import markdown
|
||||
from mautrix.util import markdown, formatter
|
||||
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
|
||||
MessageType, TextMessageEventContent, Format, RelatesTo)
|
||||
|
||||
|
@ -32,7 +33,7 @@ def parse_formatted(message: str, allow_html: bool = False, render_markdown: boo
|
|||
html = message
|
||||
else:
|
||||
return message, escape(message)
|
||||
return parse_html(html), html
|
||||
return formatter.parse_html(html), html
|
||||
|
||||
|
||||
class MaubotMessageEvent(MessageEvent):
|
||||
|
@ -110,12 +111,12 @@ class MaubotMatrixClient(MatrixClient):
|
|||
content.set_edit(edits)
|
||||
return self.send_message(room_id, content, **kwargs)
|
||||
|
||||
async def dispatch_event(self, event: Event, source: SyncStream = SyncStream.INTERNAL) -> None:
|
||||
def dispatch_event(self, event: Event, source: SyncStream) -> List[asyncio.Task]:
|
||||
if isinstance(event, MessageEvent):
|
||||
event = MaubotMessageEvent(event, self)
|
||||
elif source != SyncStream.INTERNAL:
|
||||
event.client = self
|
||||
return await super().dispatch_event(event, source)
|
||||
return super().dispatch_event(event, source)
|
||||
|
||||
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
|
||||
event = await super().get_event(room_id, event_id)
|
||||
|
|
|
@ -36,7 +36,7 @@ from .config import Config
|
|||
from ..plugin_base import Plugin
|
||||
from ..loader import PluginMeta
|
||||
from ..matrix import MaubotMatrixClient
|
||||
from ..lib.store_proxy import ClientStoreProxy
|
||||
from ..lib.store_proxy import SyncStoreProxy
|
||||
from ..__meta__ import __version__
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -143,7 +143,7 @@ async def main():
|
|||
global client, bot
|
||||
|
||||
client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token,
|
||||
client_session=http_client, loop=loop, store=ClientStoreProxy(nb),
|
||||
client_session=http_client, loop=loop, store=SyncStoreProxy(nb),
|
||||
log=logging.getLogger("maubot.client").getChild(user_id))
|
||||
|
||||
while True:
|
||||
|
|
11
optional-requirements.txt
Normal file
11
optional-requirements.txt
Normal file
|
@ -0,0 +1,11 @@
|
|||
# Format: #/name defines a new extras_require group called name
|
||||
# Uncommented lines after the group definition insert things into that group.
|
||||
|
||||
#/postgres
|
||||
psycopg2-binary>=2,<3
|
||||
|
||||
#/e2be
|
||||
asyncpg>=0.20,<0.21
|
||||
python-olm>=3,<4
|
||||
pycryptodome>=3,<4
|
||||
unpaddedbase64>=1,<2
|
|
@ -1,4 +1,4 @@
|
|||
mautrix==0.6.0.beta7
|
||||
mautrix==0.6.0rc1
|
||||
aiohttp>=3,<4
|
||||
SQLAlchemy>=1,<2
|
||||
alembic>=1,<2
|
||||
|
|
14
setup.py
14
setup.py
|
@ -5,6 +5,19 @@ import os
|
|||
with open("requirements.txt") as reqs:
|
||||
install_requires = reqs.read().splitlines()
|
||||
|
||||
with open("optional-requirements.txt") as reqs:
|
||||
extras_require = {}
|
||||
current = []
|
||||
for line in reqs.read().splitlines():
|
||||
if line.startswith("#/"):
|
||||
extras_require[line[2:]] = current = []
|
||||
elif not line or line.startswith("#"):
|
||||
continue
|
||||
else:
|
||||
current.append(line)
|
||||
|
||||
extras_require["all"] = list({dep for deps in extras_require.values() for dep in deps})
|
||||
|
||||
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "maubot", "__meta__.py")
|
||||
__version__ = "UNKNOWN"
|
||||
with open(path) as f:
|
||||
|
@ -25,6 +38,7 @@ setuptools.setup(
|
|||
packages=setuptools.find_packages(),
|
||||
|
||||
install_requires=install_requires,
|
||||
extras_require=extras_require,
|
||||
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
|
|
Loading…
Reference in a new issue