Stop using SQLAlchemy ORM and add colorful logs
This commit is contained in:
parent
59998b99b1
commit
b59eab2953
8 changed files with 90 additions and 65 deletions
|
@ -72,18 +72,21 @@ api_features:
|
|||
logging:
|
||||
version: 1
|
||||
formatters:
|
||||
precise:
|
||||
colored:
|
||||
(): maubot.lib.color_log.ColorFormatter
|
||||
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
|
||||
normal:
|
||||
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
|
||||
handlers:
|
||||
file:
|
||||
class: logging.handlers.RotatingFileHandler
|
||||
formatter: precise
|
||||
filename: ./logs/maubot.log
|
||||
formatter: normal
|
||||
filename: ./maubot.log
|
||||
maxBytes: 10485760
|
||||
backupCount: 10
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
formatter: colored
|
||||
loggers:
|
||||
maubot:
|
||||
level: DEBUG
|
||||
|
|
|
@ -56,11 +56,11 @@ log.info(f"Initializing maubot {__version__}")
|
|||
loop = asyncio.get_event_loop()
|
||||
|
||||
init_zip_loader(config)
|
||||
db_session = init_db(config)
|
||||
clients = init_client_class(db_session, loop)
|
||||
db_engine = init_db(config)
|
||||
clients = init_client_class(loop)
|
||||
management_api = init_mgmt_api(config, loop)
|
||||
server = MaubotServer(management_api, config, loop)
|
||||
plugins = init_plugin_instance_class(db_session, config, server, loop)
|
||||
plugins = init_plugin_instance_class(config, server, loop)
|
||||
|
||||
for plugin in plugins:
|
||||
plugin.load()
|
||||
|
@ -69,30 +69,17 @@ signal.signal(signal.SIGINT, signal.default_int_handler)
|
|||
signal.signal(signal.SIGTERM, signal.default_int_handler)
|
||||
|
||||
|
||||
async def periodic_commit():
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
periodic_commit_task: asyncio.Future = None
|
||||
|
||||
try:
|
||||
log.info("Starting server")
|
||||
loop.run_until_complete(server.start())
|
||||
log.info("Starting clients and plugins")
|
||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients], loop=loop))
|
||||
log.info("Startup actions complete, running forever")
|
||||
periodic_commit_task = asyncio.ensure_future(periodic_commit(), loop=loop)
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
log.info("Interrupt received, stopping HTTP clients/servers and saving database")
|
||||
if periodic_commit_task is not None:
|
||||
periodic_commit_task.cancel()
|
||||
log.debug("Stopping clients")
|
||||
log.info("Interrupt received, stopping clients")
|
||||
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()],
|
||||
loop=loop))
|
||||
db_session.commit()
|
||||
if stop_log_listener is not None:
|
||||
log.debug("Closing websockets")
|
||||
loop.run_until_complete(stop_log_listener())
|
||||
|
|
|
@ -13,11 +13,10 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
|
||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||
|
@ -35,7 +34,6 @@ log = logging.getLogger("maubot.client")
|
|||
|
||||
|
||||
class Client:
|
||||
db: Session = None
|
||||
log: logging.Logger = None
|
||||
loop: asyncio.AbstractEventLoop = None
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
|
@ -148,9 +146,7 @@ class Client:
|
|||
|
||||
def clear_cache(self) -> None:
|
||||
self.stop_sync()
|
||||
self.db_instance.filter_id = ""
|
||||
self.db_instance.next_batch = ""
|
||||
self.db.commit()
|
||||
self.db_instance.edit(filter_id="", next_batch="")
|
||||
self.start_sync()
|
||||
|
||||
def delete(self) -> None:
|
||||
|
@ -158,8 +154,7 @@ class Client:
|
|||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db.delete(self.db_instance)
|
||||
self.db.commit()
|
||||
self.db_instance.delete()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
@ -183,14 +178,14 @@ class Client:
|
|||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBClient.query.get(user_id)
|
||||
db_instance = db_instance or DBClient.get(user_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return Client(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> List['Client']:
|
||||
return [cls.get(user.id, user) for user in DBClient.query.all()]
|
||||
def all(cls) -> Iterable['Client']:
|
||||
return (cls.get(user.id, user) for user in DBClient.all())
|
||||
|
||||
async def _handle_invite(self, evt: StrippedStateEvent) -> None:
|
||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
|
@ -314,8 +309,7 @@ class Client:
|
|||
# endregion
|
||||
|
||||
|
||||
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
|
||||
Client.db = db
|
||||
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
return Client.all()
|
||||
|
|
42
maubot/db.py
42
maubot/db.py
|
@ -13,22 +13,19 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import cast
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
|
||||
from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.engine.base import Engine
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
|
||||
from mautrix.bridge.db import Base
|
||||
|
||||
from .config import Config
|
||||
|
||||
Base: declarative_base = declarative_base()
|
||||
|
||||
|
||||
class DBPlugin(Base):
|
||||
query: Query
|
||||
__tablename__ = "plugin"
|
||||
|
||||
id: str = Column(String(255), primary_key=True)
|
||||
|
@ -39,9 +36,16 @@ class DBPlugin(Base):
|
|||
nullable=False)
|
||||
config: str = Column(Text, nullable=False, default='')
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable['DBPlugin']:
|
||||
return cls._select_all()
|
||||
|
||||
@classmethod
|
||||
def get(cls, id: str) -> Optional['DBPlugin']:
|
||||
return cls._select_one_or_none(cls.c.id == id)
|
||||
|
||||
|
||||
class DBClient(Base):
|
||||
query: Query
|
||||
__tablename__ = "client"
|
||||
|
||||
id: UserID = Column(String(255), primary_key=True)
|
||||
|
@ -58,15 +62,23 @@ class DBClient(Base):
|
|||
displayname: str = Column(String(255), nullable=False, default="")
|
||||
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable['DBClient']:
|
||||
return cls._select_all()
|
||||
|
||||
def init(config: Config) -> Session:
|
||||
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
|
||||
db_factory = sessionmaker(bind=db_engine)
|
||||
db_session = scoped_session(db_factory)
|
||||
@classmethod
|
||||
def get(cls, id: str) -> Optional['DBClient']:
|
||||
return cls._select_one_or_none(cls.c.id == id)
|
||||
|
||||
|
||||
def init(config: Config) -> Engine:
|
||||
db_engine = sql.create_engine(config["database"])
|
||||
Base.metadata.bind = db_engine
|
||||
Base.metadata.create_all()
|
||||
|
||||
DBPlugin.query = db_session.query_property()
|
||||
DBClient.query = db_session.query_property()
|
||||
for table in (DBPlugin, DBClient):
|
||||
table.db = db_engine
|
||||
table.t = table.__table__
|
||||
table.c = table.t.c
|
||||
table.column_names = table.c.keys()
|
||||
|
||||
return cast(Session, db_session)
|
||||
return db_engine
|
||||
|
|
|
@ -13,16 +13,14 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, Iterable, TYPE_CHECKING
|
||||
from asyncio import AbstractEventLoop
|
||||
from aiohttp import web
|
||||
import os.path
|
||||
import logging
|
||||
import io
|
||||
|
||||
from ruamel.yaml.comments import CommentedMap
|
||||
from ruamel.yaml import YAML
|
||||
from sqlalchemy.orm import Session
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
||||
|
@ -44,7 +42,6 @@ yaml.indent(4)
|
|||
|
||||
|
||||
class PluginInstance:
|
||||
db: Session = None
|
||||
webserver: 'MaubotServer' = None
|
||||
mb_config: Config = None
|
||||
loop: AbstractEventLoop = None
|
||||
|
@ -130,8 +127,7 @@ class PluginInstance:
|
|||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db.delete(self.db_instance)
|
||||
self.db.commit()
|
||||
self.db_instance.delete()
|
||||
if self.inst_db:
|
||||
self.inst_db.dispose()
|
||||
ZippedPluginLoader.trash(
|
||||
|
@ -207,14 +203,14 @@ class PluginInstance:
|
|||
try:
|
||||
return cls.cache[instance_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBPlugin.query.get(instance_id)
|
||||
db_instance = db_instance or DBPlugin.get(instance_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return PluginInstance(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> List['PluginInstance']:
|
||||
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
|
||||
def all(cls) -> Iterable['PluginInstance']:
|
||||
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
|
||||
|
||||
def update_id(self, new_id: str) -> None:
|
||||
if new_id is not None and new_id != self.id:
|
||||
|
@ -293,9 +289,8 @@ class PluginInstance:
|
|||
# endregion
|
||||
|
||||
|
||||
def init(db: Session, config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop) -> List[
|
||||
PluginInstance]:
|
||||
PluginInstance.db = db
|
||||
def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop
|
||||
) -> Iterable[PluginInstance]:
|
||||
PluginInstance.mb_config = config
|
||||
PluginInstance.loop = loop
|
||||
PluginInstance.webserver = webserver
|
||||
|
|
36
maubot/lib/color_log.py
Normal file
36
maubot/lib/color_log.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from mautrix.util.color_log import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR,
|
||||
MXID_COLOR, RESET)
|
||||
|
||||
INST_COLOR = PREFIX + "35m" # magenta
|
||||
LOADER_COLOR = PREFIX + "36m" # blue
|
||||
|
||||
|
||||
class ColorFormatter(BaseColorFormatter):
|
||||
def _color_name(self, module: str) -> str:
|
||||
client = "maubot.client"
|
||||
if module.startswith(client):
|
||||
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
|
||||
instance = "maubot.instance"
|
||||
if module.startswith(instance):
|
||||
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
|
||||
loader = "maubot.loader"
|
||||
if module.startswith(loader):
|
||||
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
|
||||
if module.startswith("maubot"):
|
||||
return f"{MAU_COLOR}{module}{RESET}"
|
||||
return super()._color_name(module)
|
|
@ -68,8 +68,7 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
|
|||
displayname=data.get("displayname", ""),
|
||||
avatar_url=data.get("avatar_url", ""))
|
||||
client = Client(db_instance)
|
||||
Client.db.add(db_instance)
|
||||
Client.db.commit()
|
||||
client.db_instance.insert()
|
||||
await client.start()
|
||||
return resp.created(client.to_dict())
|
||||
|
||||
|
|
|
@ -56,8 +56,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
|
|||
primary_user=primary_user, config=data.get("config", ""))
|
||||
instance = PluginInstance(db_instance)
|
||||
instance.load()
|
||||
PluginInstance.db.add(db_instance)
|
||||
PluginInstance.db.commit()
|
||||
instance.db_instance.insert()
|
||||
await instance.start()
|
||||
return resp.created(instance.to_dict())
|
||||
|
||||
|
|
Loading…
Reference in a new issue