Stop using SQLAlchemy ORM and add colorful logs

This commit is contained in:
Tulir Asokan 2019-09-01 14:46:08 +03:00
parent 59998b99b1
commit b59eab2953
8 changed files with 90 additions and 65 deletions

View file

@ -72,18 +72,21 @@ api_features:
logging:
version: 1
formatters:
precise:
colored:
(): maubot.lib.color_log.ColorFormatter
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
normal:
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: ./logs/maubot.log
formatter: normal
filename: ./maubot.log
maxBytes: 10485760
backupCount: 10
console:
class: logging.StreamHandler
formatter: precise
formatter: colored
loggers:
maubot:
level: DEBUG

View file

@ -56,11 +56,11 @@ log.info(f"Initializing maubot {__version__}")
loop = asyncio.get_event_loop()
init_zip_loader(config)
db_session = init_db(config)
clients = init_client_class(db_session, loop)
db_engine = init_db(config)
clients = init_client_class(loop)
management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(db_session, config, server, loop)
plugins = init_plugin_instance_class(config, server, loop)
for plugin in plugins:
plugin.load()
@ -69,30 +69,17 @@ signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
async def periodic_commit():
while True:
await asyncio.sleep(60)
db_session.commit()
periodic_commit_task: asyncio.Future = None
try:
log.info("Starting server")
loop.run_until_complete(server.start())
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
log.info("Startup actions complete, running forever")
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
loop.run_forever()
except KeyboardInterrupt:
log.info("Interrupt received, stopping HTTP clients/servers and saving database")
if periodic_commit_task is not None:
periodic_commit_task.cancel()
log.debug("Stopping clients")
log.info("Interrupt received, stopping clients")
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
loop=loop))
db_session.commit()
if stop_log_listener is not None:
log.debug("Closing websockets")
loop.run_until_complete(stop_log_listener())

View file

@ -13,11 +13,10 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
import asyncio
import logging
from sqlalchemy.orm import Session
from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
@ -35,7 +34,6 @@ log = logging.getLogger("maubot.client")
class Client:
db: Session = None
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
@ -148,9 +146,7 @@ class Client:
def clear_cache(self) -> None:
self.stop_sync()
self.db_instance.filter_id = ""
self.db_instance.next_batch = ""
self.db.commit()
self.db_instance.edit(filter_id="", next_batch="")
self.start_sync()
def delete(self) -> None:
@ -158,8 +154,7 @@ class Client:
del self.cache[self.id]
except KeyError:
pass
self.db.delete(self.db_instance)
self.db.commit()
self.db_instance.delete()
def to_dict(self) -> dict:
return {
@ -183,14 +178,14 @@ class Client:
try:
return cls.cache[user_id]
except KeyError:
db_instance = db_instance or DBClient.query.get(user_id)
db_instance = db_instance or DBClient.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
def all(cls) -> List['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()]
def all(cls) -> Iterable['Client']:
return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_invite(self, evt: StrippedStateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
@ -314,8 +309,7 @@ class Client:
# endregion
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
Client.db = db
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
return Client.all()

View file

@ -13,22 +13,19 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import cast
from typing import Iterable, Optional
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.engine.base import Engine
import sqlalchemy as sql
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
from mautrix.bridge.db import Base
from .config import Config
Base: declarative_base = declarative_base()
class DBPlugin(Base):
query: Query
__tablename__ = "plugin"
id: str = Column(String(255), primary_key=True)
@ -39,9 +36,16 @@ class DBPlugin(Base):
nullable=False)
config: str = Column(Text, nullable=False, default='')
@classmethod
def all(cls) -> Iterable['DBPlugin']:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional['DBPlugin']:
return cls._select_one_or_none(cls.c.id == id)
class DBClient(Base):
query: Query
__tablename__ = "client"
id: UserID = Column(String(255), primary_key=True)
@ -58,15 +62,23 @@ class DBClient(Base):
displayname: str = Column(String(255), nullable=False, default="")
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
@classmethod
def all(cls) -> Iterable['DBClient']:
return cls._select_all()
def init(config: Config) -> Session:
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
db_factory = sessionmaker(bind=db_engine)
db_session = scoped_session(db_factory)
@classmethod
def get(cls, id: str) -> Optional['DBClient']:
return cls._select_one_or_none(cls.c.id == id)
def init(config: Config) -> Engine:
db_engine = sql.create_engine(config["database"])
Base.metadata.bind = db_engine
Base.metadata.create_all()
DBPlugin.query = db_session.query_property()
DBClient.query = db_session.query_property()
for table in (DBPlugin, DBClient):
table.db = db_engine
table.t = table.__table__
table.c = table.t.c
table.column_names = table.c.keys()
return cast(Session, db_session)
return db_engine

View file

@ -13,16 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, Iterable, TYPE_CHECKING
from asyncio import AbstractEventLoop
from aiohttp import web
import os.path
import logging
import io
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
from sqlalchemy.orm import Session
import sqlalchemy as sql
from mautrix.util.config import BaseProxyConfig, RecursiveDict
@ -44,7 +42,6 @@ yaml.indent(4)
class PluginInstance:
db: Session = None
webserver: 'MaubotServer' = None
mb_config: Config = None
loop: AbstractEventLoop = None
@ -130,8 +127,7 @@ class PluginInstance:
del self.cache[self.id]
except KeyError:
pass
self.db.delete(self.db_instance)
self.db.commit()
self.db_instance.delete()
if self.inst_db:
self.inst_db.dispose()
ZippedPluginLoader.trash(
@ -207,14 +203,14 @@ class PluginInstance:
try:
return cls.cache[instance_id]
except KeyError:
db_instance = db_instance or DBPlugin.query.get(instance_id)
db_instance = db_instance or DBPlugin.get(instance_id)
if not db_instance:
return None
return PluginInstance(db_instance)
@classmethod
def all(cls) -> List['PluginInstance']:
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
def all(cls) -> Iterable['PluginInstance']:
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
def update_id(self, new_id: str) -> None:
if new_id is not None and new_id != self.id:
@ -293,9 +289,8 @@ class PluginInstance:
# endregion
def init(db: Session, config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop) -> List[
PluginInstance]:
PluginInstance.db = db
def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop
) -> Iterable[PluginInstance]:
PluginInstance.mb_config = config
PluginInstance.loop = loop
PluginInstance.webserver = webserver

36
maubot/lib/color_log.py Normal file
View file

@ -0,0 +1,36 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.color_log import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR,
MXID_COLOR, RESET)
INST_COLOR = PREFIX + "35m" # magenta
LOADER_COLOR = PREFIX + "36m" # blue
class ColorFormatter(BaseColorFormatter):
def _color_name(self, module: str) -> str:
client = "maubot.client"
if module.startswith(client):
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
instance = "maubot.instance"
if module.startswith(instance):
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
loader = "maubot.loader"
if module.startswith(loader):
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
if module.startswith("maubot"):
return f"{MAU_COLOR}{module}{RESET}"
return super()._color_name(module)

View file

@ -68,8 +68,7 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
displayname=data.get("displayname", ""),
avatar_url=data.get("avatar_url", ""))
client = Client(db_instance)
Client.db.add(db_instance)
Client.db.commit()
client.db_instance.insert()
await client.start()
return resp.created(client.to_dict())

View file

@ -56,8 +56,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
primary_user=primary_user, config=data.get("config", ""))
instance = PluginInstance(db_instance)
instance.load()
PluginInstance.db.add(db_instance)
PluginInstance.db.commit()
instance.db_instance.insert()
await instance.start()
return resp.created(instance.to_dict())