Add command matching stuff
This commit is contained in:
parent
c79ed97a47
commit
0b246e44a8
7 changed files with 196 additions and 34 deletions
|
@ -1,3 +1,4 @@
|
||||||
from .plugin_base import Plugin
|
from .plugin_base import Plugin
|
||||||
from .command_spec import CommandSpec, Command, PassiveCommand, Argument
|
from .command_spec import CommandSpec, Command, PassiveCommand, Argument
|
||||||
from .event import FakeEvent as Event
|
from .event import FakeEvent as Event
|
||||||
|
from .client import MaubotMatrixClient as Client
|
||||||
|
|
|
@ -13,33 +13,80 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional, Union, Callable
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from mautrix import Client as MatrixClient
|
from mautrix import Client as MatrixClient
|
||||||
|
from mautrix.client import EventHandler
|
||||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
|
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
|
||||||
EventType)
|
EventType, MessageEvent)
|
||||||
|
|
||||||
|
from .command_spec import ParsedCommand
|
||||||
from .db import DBClient
|
from .db import DBClient
|
||||||
|
|
||||||
log = logging.getLogger("maubot.client")
|
log = logging.getLogger("maubot.client")
|
||||||
|
|
||||||
|
|
||||||
|
class MaubotMatrixClient(MatrixClient):
|
||||||
|
def __init__(self, maubot_client: 'Client', *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._maubot_client = maubot_client
|
||||||
|
self.command_handlers: Dict[str, List[EventHandler]] = {}
|
||||||
|
self.commands: List[ParsedCommand] = []
|
||||||
|
|
||||||
|
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
|
||||||
|
|
||||||
|
async def _command_event_handler(self, evt: MessageEvent) -> None:
|
||||||
|
for command in self.commands:
|
||||||
|
if command.match(evt):
|
||||||
|
await self._trigger_command(command, evt)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
|
||||||
|
for handler in self.command_handlers.get(command.name, []):
|
||||||
|
await handler(evt)
|
||||||
|
|
||||||
|
def on(self, var: Union[EventHandler, EventType, str]
|
||||||
|
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
|
||||||
|
if isinstance(var, str):
|
||||||
|
def decorator(func: EventHandler) -> EventHandler:
|
||||||
|
self.add_command_handler(var, func)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
return super().on(var)
|
||||||
|
|
||||||
|
def add_command_handler(self, command: str, handler: EventHandler) -> None:
|
||||||
|
self.command_handlers.setdefault(command, []).append(handler)
|
||||||
|
|
||||||
|
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
|
||||||
|
try:
|
||||||
|
self.command_handlers[command].remove(handler)
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
cache: Dict[UserID, 'Client'] = {}
|
cache: Dict[UserID, 'Client'] = {}
|
||||||
http_client: ClientSession = None
|
http_client: ClientSession = None
|
||||||
|
|
||||||
|
db_instance: DBClient
|
||||||
|
client: MaubotMatrixClient
|
||||||
|
|
||||||
def __init__(self, db_instance: DBClient) -> None:
|
def __init__(self, db_instance: DBClient) -> None:
|
||||||
self.db_instance: DBClient = db_instance
|
self.db_instance = db_instance
|
||||||
self.cache[self.id] = self
|
self.cache[self.id] = self
|
||||||
self.client: MatrixClient = MatrixClient(mxid=self.id,
|
self.client = MaubotMatrixClient(maubot_client=self,
|
||||||
base_url=self.homeserver,
|
store=self.db_instance,
|
||||||
token=self.access_token,
|
mxid=self.id,
|
||||||
client_session=self.http_client,
|
base_url=self.homeserver,
|
||||||
log=log.getChild(self.id))
|
token=self.access_token,
|
||||||
|
client_session=self.http_client,
|
||||||
|
log=log.getChild(self.id))
|
||||||
if self.autojoin:
|
if self.autojoin:
|
||||||
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
|
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, user_id: UserID) -> Optional['Client']:
|
def get(cls, user_id: UserID) -> Optional['Client']:
|
||||||
|
@ -103,9 +150,9 @@ class Client:
|
||||||
if value == self.db_instance.autojoin:
|
if value == self.db_instance.autojoin:
|
||||||
return
|
return
|
||||||
if value:
|
if value:
|
||||||
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
|
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||||
else:
|
else:
|
||||||
self.client.remove_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
|
self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||||
self.db_instance.autojoin = value
|
self.db_instance.autojoin = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -126,6 +173,6 @@ class Client:
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
async def handle_invite(self, evt: StateEvent) -> None:
|
async def _handle_invite(self, evt: StateEvent) -> None:
|
||||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||||
await self.client.join_room_by_id(evt.room_id)
|
await self.client.join_room_by_id(evt.room_id)
|
||||||
|
|
|
@ -13,10 +13,11 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Pattern, Union, Tuple, Optional, Any
|
||||||
from attr import dataclass
|
from attr import dataclass
|
||||||
|
import re
|
||||||
|
|
||||||
from mautrix.types import Event
|
from mautrix.types import MessageEvent, MatchedCommand, MatchedPassiveCommand
|
||||||
from mautrix.client.api.types.util import SerializableAttrs
|
from mautrix.client.api.types.util import SerializableAttrs
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,7 +40,103 @@ class PassiveCommand(SerializableAttrs['PassiveCommand']):
|
||||||
name: str
|
name: str
|
||||||
matches: str
|
matches: str
|
||||||
match_against: str
|
match_against: str
|
||||||
match_event: Event = None
|
match_event: MessageEvent = None
|
||||||
|
|
||||||
|
|
||||||
|
class ParsedCommand:
|
||||||
|
name: str
|
||||||
|
is_passive: bool
|
||||||
|
arguments: List[str]
|
||||||
|
starts_with: str
|
||||||
|
matches: Pattern
|
||||||
|
match_against: str
|
||||||
|
match_event: MessageEvent
|
||||||
|
|
||||||
|
def __init__(self, command: Union[PassiveCommand, Command]) -> None:
|
||||||
|
if isinstance(command, PassiveCommand):
|
||||||
|
self._init_passive(command)
|
||||||
|
elif isinstance(command, Command):
|
||||||
|
self._init_active(command)
|
||||||
|
else:
|
||||||
|
raise ValueError("Command parameter must be a Command or a PassiveCommand.")
|
||||||
|
|
||||||
|
def _init_passive(self, command: PassiveCommand) -> None:
|
||||||
|
self.name = command.name
|
||||||
|
self.is_passive = True
|
||||||
|
self.match_against = command.match_against
|
||||||
|
self.matches = re.compile(command.matches)
|
||||||
|
self.match_event = command.match_event
|
||||||
|
|
||||||
|
def _init_active(self, command: Command) -> None:
|
||||||
|
self.name = command.syntax
|
||||||
|
self.is_passive = False
|
||||||
|
|
||||||
|
regex_builder = []
|
||||||
|
sw_builder = []
|
||||||
|
argument_encountered = False
|
||||||
|
|
||||||
|
for word in command.syntax.split(" "):
|
||||||
|
arg = command.arguments.get(word, None)
|
||||||
|
if arg is not None and len(word) > 0:
|
||||||
|
argument_encountered = True
|
||||||
|
regex = f"({arg.matches})" if arg.required else f"(?:{arg.matches})?"
|
||||||
|
self.arguments.append(word)
|
||||||
|
regex_builder.append(regex)
|
||||||
|
else:
|
||||||
|
if not argument_encountered:
|
||||||
|
sw_builder.append(word)
|
||||||
|
regex_builder.append(re.escape(word))
|
||||||
|
self.starts_with = "!" + " ".join(sw_builder)
|
||||||
|
self.matches = re.compile("^!" + " ".join(regex_builder) + "$")
|
||||||
|
self.match_against = "body"
|
||||||
|
|
||||||
|
def match(self, evt: MessageEvent) -> bool:
|
||||||
|
return self._match_passive(evt) if self.is_passive else self._match_active(evt)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
|
||||||
|
if '.' not in key:
|
||||||
|
return key, None
|
||||||
|
key, next_key = key.split('.', 1)
|
||||||
|
if len(key) > 0 and key[0] == "[":
|
||||||
|
end_index = next_key.index("]")
|
||||||
|
key = key[1:] + "." + next_key[:end_index]
|
||||||
|
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
|
||||||
|
return key, next_key
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _recursive_get(cls, data: Any, key: str) -> Any:
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
key, next_key = cls._parse_key(key)
|
||||||
|
if next_key is not None:
|
||||||
|
return cls._recursive_get(data[key], next_key)
|
||||||
|
return data[key]
|
||||||
|
|
||||||
|
def _match_passive(self, evt: MessageEvent) -> bool:
|
||||||
|
try:
|
||||||
|
match_against = self._recursive_get(evt.content, self.match_against)
|
||||||
|
except KeyError:
|
||||||
|
match_against = None
|
||||||
|
match_against = match_against or evt.content.body
|
||||||
|
matches = [[match.string[match.start():match.end()]] + list(match.groups())
|
||||||
|
for match in self.matches.finditer(match_against)]
|
||||||
|
if not matches:
|
||||||
|
return False
|
||||||
|
if evt.unsigned.passive_command is None:
|
||||||
|
evt.unsigned.passive_command = {}
|
||||||
|
evt.unsigned.passive_command[self.name] = MatchedPassiveCommand(captured=matches)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _match_active(self, evt: MessageEvent) -> bool:
|
||||||
|
if not evt.content.body.startswith(self.starts_with):
|
||||||
|
return False
|
||||||
|
match = self.matches.match(evt.content.body)
|
||||||
|
if not match:
|
||||||
|
return False
|
||||||
|
evt.content.command = MatchedCommand(matched=self.name,
|
||||||
|
arguments=dict(zip(self.arguments, match.groups())))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -50,3 +147,6 @@ class CommandSpec(SerializableAttrs['CommandSpec']):
|
||||||
def __add__(self, other: 'CommandSpec') -> 'CommandSpec':
|
def __add__(self, other: 'CommandSpec') -> 'CommandSpec':
|
||||||
return CommandSpec(commands=self.commands + other.commands,
|
return CommandSpec(commands=self.commands + other.commands,
|
||||||
passive_commands=self.passive_commands + other.passive_commands)
|
passive_commands=self.passive_commands + other.passive_commands)
|
||||||
|
|
||||||
|
def parse(self) -> List[ParsedCommand]:
|
||||||
|
return [ParsedCommand(command) for command in self.commands + self.passive_commands]
|
||||||
|
|
42
maubot/db.py
42
maubot/db.py
|
@ -13,31 +13,39 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from typing import Type
|
||||||
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
|
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
|
||||||
from sqlalchemy.orm import Query
|
from sqlalchemy.orm import Query
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from mautrix.types import JSON, UserID, FilterID, SyncToken, ContentURI
|
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
|
||||||
|
from mautrix.client.api.types.util import Serializable
|
||||||
|
from mautrix import ClientStore
|
||||||
|
|
||||||
|
from .command_spec import CommandSpec
|
||||||
|
|
||||||
Base: declarative_base = declarative_base()
|
Base: declarative_base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class JSONEncodedDict(TypeDecorator):
|
def make_serializable_alchemy(serializable_type: Type[Serializable]):
|
||||||
impl = Text
|
class SerializableAlchemy(TypeDecorator):
|
||||||
|
impl = Text
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def python_type(self):
|
def python_type(self):
|
||||||
return dict
|
return serializable_type
|
||||||
|
|
||||||
def process_literal_param(self, value, _):
|
def process_literal_param(self, value: Serializable, _) -> str:
|
||||||
return json.dumps(value) if value is not None else None
|
return json.dumps(value.serialize()) if value is not None else None
|
||||||
|
|
||||||
def process_bind_param(self, value, _):
|
def process_bind_param(self, value: Serializable, _) -> str:
|
||||||
return json.dumps(value) if value is not None else None
|
return json.dumps(value.serialize()) if value is not None else None
|
||||||
|
|
||||||
def process_result_value(self, value, _):
|
def process_result_value(self, value: str, _) -> serializable_type:
|
||||||
return json.loads(value) if value is not None else None
|
return serializable_type.deserialize(json.loads(value)) if value is not None else None
|
||||||
|
|
||||||
|
return SerializableAlchemy
|
||||||
|
|
||||||
|
|
||||||
class DBPlugin(Base):
|
class DBPlugin(Base):
|
||||||
|
@ -52,7 +60,7 @@ class DBPlugin(Base):
|
||||||
nullable=False)
|
nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class DBClient(Base):
|
class DBClient(ClientStore, Base):
|
||||||
query: Query
|
query: Query
|
||||||
__tablename__ = "client"
|
__tablename__ = "client"
|
||||||
|
|
||||||
|
@ -74,10 +82,10 @@ class DBCommandSpec(Base):
|
||||||
query: Query
|
query: Query
|
||||||
__tablename__ = "command_spec"
|
__tablename__ = "command_spec"
|
||||||
|
|
||||||
owner: str = Column(String(255),
|
plugin: str = Column(String(255),
|
||||||
ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"),
|
ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"),
|
||||||
primary_key=True)
|
primary_key=True)
|
||||||
client: UserID = Column(String(255),
|
client: UserID = Column(String(255),
|
||||||
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
|
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
|
||||||
primary_key=True)
|
primary_key=True)
|
||||||
spec: JSON = Column(JSONEncodedDict, nullable=False)
|
spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)
|
||||||
|
|
|
@ -17,13 +17,17 @@ from typing import TYPE_CHECKING
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from mautrix import Client as MatrixClient
|
from .client import MaubotMatrixClient
|
||||||
|
from .command_spec import CommandSpec
|
||||||
|
|
||||||
|
|
||||||
class Plugin(ABC):
|
class Plugin(ABC):
|
||||||
def __init__(self, client: 'MatrixClient') -> None:
|
def __init__(self, client: 'MaubotMatrixClient') -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
|
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -4,3 +4,4 @@ SQLAlchemy
|
||||||
alembic
|
alembic
|
||||||
commonmark
|
commonmark
|
||||||
ruamel.yaml
|
ruamel.yaml
|
||||||
|
attrs
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -27,6 +27,7 @@ setuptools.setup(
|
||||||
"alembic>=1.0.0,<2",
|
"alembic>=1.0.0,<2",
|
||||||
"commonmark>=0.8.1,<1",
|
"commonmark>=0.8.1,<1",
|
||||||
"ruamel.yaml>=0.15.35,<0.16",
|
"ruamel.yaml>=0.15.35,<0.16",
|
||||||
|
"attrs>=18.2.0,<19",
|
||||||
],
|
],
|
||||||
|
|
||||||
classifiers=[
|
classifiers=[
|
||||||
|
|
Loading…
Reference in a new issue