Add command matching stuff

This commit is contained in:
Tulir Asokan 2018-10-16 00:25:23 +03:00
parent c79ed97a47
commit 0b246e44a8
7 changed files with 196 additions and 34 deletions

View file

@ -1,3 +1,4 @@
from .plugin_base import Plugin from .plugin_base import Plugin
from .command_spec import CommandSpec, Command, PassiveCommand, Argument from .command_spec import CommandSpec, Command, PassiveCommand, Argument
from .event import FakeEvent as Event from .event import FakeEvent as Event
from .client import MaubotMatrixClient as Client

View file

@ -13,33 +13,80 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Optional from typing import Dict, List, Optional, Union, Callable
from aiohttp import ClientSession from aiohttp import ClientSession
import asyncio
import logging import logging
from mautrix import Client as MatrixClient from mautrix import Client as MatrixClient
from mautrix.client import EventHandler
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
EventType) EventType, MessageEvent)
from .command_spec import ParsedCommand
from .db import DBClient from .db import DBClient
log = logging.getLogger("maubot.client") log = logging.getLogger("maubot.client")
class MaubotMatrixClient(MatrixClient):
def __init__(self, maubot_client: 'Client', *args, **kwargs):
super().__init__(*args, **kwargs)
self._maubot_client = maubot_client
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
async def _command_event_handler(self, evt: MessageEvent) -> None:
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass
class Client: class Client:
cache: Dict[UserID, 'Client'] = {} cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None http_client: ClientSession = None
db_instance: DBClient
client: MaubotMatrixClient
def __init__(self, db_instance: DBClient) -> None: def __init__(self, db_instance: DBClient) -> None:
self.db_instance: DBClient = db_instance self.db_instance = db_instance
self.cache[self.id] = self self.cache[self.id] = self
self.client: MatrixClient = MatrixClient(mxid=self.id, self.client = MaubotMatrixClient(maubot_client=self,
base_url=self.homeserver, store=self.db_instance,
token=self.access_token, mxid=self.id,
client_session=self.http_client, base_url=self.homeserver,
log=log.getChild(self.id)) token=self.access_token,
client_session=self.http_client,
log=log.getChild(self.id))
if self.autojoin: if self.autojoin:
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER) self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
@classmethod @classmethod
def get(cls, user_id: UserID) -> Optional['Client']: def get(cls, user_id: UserID) -> Optional['Client']:
@ -103,9 +150,9 @@ class Client:
if value == self.db_instance.autojoin: if value == self.db_instance.autojoin:
return return
if value: if value:
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER) self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
else: else:
self.client.remove_event_handler(self.handle_invite, EventType.ROOM_MEMBER) self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
self.db_instance.autojoin = value self.db_instance.autojoin = value
@property @property
@ -126,6 +173,6 @@ class Client:
# endregion # endregion
async def handle_invite(self, evt: StateEvent) -> None: async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE: if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id) await self.client.join_room_by_id(evt.room_id)

View file

@ -13,10 +13,11 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict from typing import List, Dict, Pattern, Union, Tuple, Optional, Any
from attr import dataclass from attr import dataclass
import re
from mautrix.types import Event from mautrix.types import MessageEvent, MatchedCommand, MatchedPassiveCommand
from mautrix.client.api.types.util import SerializableAttrs from mautrix.client.api.types.util import SerializableAttrs
@ -39,7 +40,103 @@ class PassiveCommand(SerializableAttrs['PassiveCommand']):
name: str name: str
matches: str matches: str
match_against: str match_against: str
match_event: Event = None match_event: MessageEvent = None
class ParsedCommand:
name: str
is_passive: bool
arguments: List[str]
starts_with: str
matches: Pattern
match_against: str
match_event: MessageEvent
def __init__(self, command: Union[PassiveCommand, Command]) -> None:
if isinstance(command, PassiveCommand):
self._init_passive(command)
elif isinstance(command, Command):
self._init_active(command)
else:
raise ValueError("Command parameter must be a Command or a PassiveCommand.")
def _init_passive(self, command: PassiveCommand) -> None:
self.name = command.name
self.is_passive = True
self.match_against = command.match_against
self.matches = re.compile(command.matches)
self.match_event = command.match_event
def _init_active(self, command: Command) -> None:
self.name = command.syntax
self.is_passive = False
regex_builder = []
sw_builder = []
argument_encountered = False
for word in command.syntax.split(" "):
arg = command.arguments.get(word, None)
if arg is not None and len(word) > 0:
argument_encountered = True
regex = f"({arg.matches})" if arg.required else f"(?:{arg.matches})?"
self.arguments.append(word)
regex_builder.append(regex)
else:
if not argument_encountered:
sw_builder.append(word)
regex_builder.append(re.escape(word))
self.starts_with = "!" + " ".join(sw_builder)
self.matches = re.compile("^!" + " ".join(regex_builder) + "$")
self.match_against = "body"
def match(self, evt: MessageEvent) -> bool:
return self._match_passive(evt) if self.is_passive else self._match_active(evt)
@staticmethod
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
if '.' not in key:
return key, None
key, next_key = key.split('.', 1)
if len(key) > 0 and key[0] == "[":
end_index = next_key.index("]")
key = key[1:] + "." + next_key[:end_index]
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
return key, next_key
@classmethod
def _recursive_get(cls, data: Any, key: str) -> Any:
if not data:
return None
key, next_key = cls._parse_key(key)
if next_key is not None:
return cls._recursive_get(data[key], next_key)
return data[key]
def _match_passive(self, evt: MessageEvent) -> bool:
try:
match_against = self._recursive_get(evt.content, self.match_against)
except KeyError:
match_against = None
match_against = match_against or evt.content.body
matches = [[match.string[match.start():match.end()]] + list(match.groups())
for match in self.matches.finditer(match_against)]
if not matches:
return False
if evt.unsigned.passive_command is None:
evt.unsigned.passive_command = {}
evt.unsigned.passive_command[self.name] = MatchedPassiveCommand(captured=matches)
return True
def _match_active(self, evt: MessageEvent) -> bool:
if not evt.content.body.startswith(self.starts_with):
return False
match = self.matches.match(evt.content.body)
if not match:
return False
evt.content.command = MatchedCommand(matched=self.name,
arguments=dict(zip(self.arguments, match.groups())))
return True
@dataclass @dataclass
@ -50,3 +147,6 @@ class CommandSpec(SerializableAttrs['CommandSpec']):
def __add__(self, other: 'CommandSpec') -> 'CommandSpec': def __add__(self, other: 'CommandSpec') -> 'CommandSpec':
return CommandSpec(commands=self.commands + other.commands, return CommandSpec(commands=self.commands + other.commands,
passive_commands=self.passive_commands + other.passive_commands) passive_commands=self.passive_commands + other.passive_commands)
def parse(self) -> List[ParsedCommand]:
return [ParsedCommand(command) for command in self.commands + self.passive_commands]

View file

@ -13,31 +13,39 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator) from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
from sqlalchemy.orm import Query from sqlalchemy.orm import Query
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
import json import json
from mautrix.types import JSON, UserID, FilterID, SyncToken, ContentURI from mautrix.types import UserID, FilterID, SyncToken, ContentURI
from mautrix.client.api.types.util import Serializable
from mautrix import ClientStore
from .command_spec import CommandSpec
Base: declarative_base = declarative_base() Base: declarative_base = declarative_base()
class JSONEncodedDict(TypeDecorator): def make_serializable_alchemy(serializable_type: Type[Serializable]):
impl = Text class SerializableAlchemy(TypeDecorator):
impl = Text
@property @property
def python_type(self): def python_type(self):
return dict return serializable_type
def process_literal_param(self, value, _): def process_literal_param(self, value: Serializable, _) -> str:
return json.dumps(value) if value is not None else None return json.dumps(value.serialize()) if value is not None else None
def process_bind_param(self, value, _): def process_bind_param(self, value: Serializable, _) -> str:
return json.dumps(value) if value is not None else None return json.dumps(value.serialize()) if value is not None else None
def process_result_value(self, value, _): def process_result_value(self, value: str, _) -> serializable_type:
return json.loads(value) if value is not None else None return serializable_type.deserialize(json.loads(value)) if value is not None else None
return SerializableAlchemy
class DBPlugin(Base): class DBPlugin(Base):
@ -52,7 +60,7 @@ class DBPlugin(Base):
nullable=False) nullable=False)
class DBClient(Base): class DBClient(ClientStore, Base):
query: Query query: Query
__tablename__ = "client" __tablename__ = "client"
@ -74,10 +82,10 @@ class DBCommandSpec(Base):
query: Query query: Query
__tablename__ = "command_spec" __tablename__ = "command_spec"
owner: str = Column(String(255), plugin: str = Column(String(255),
ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"), ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) primary_key=True)
client: UserID = Column(String(255), client: UserID = Column(String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"), ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True) primary_key=True)
spec: JSON = Column(JSONEncodedDict, nullable=False) spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)

View file

@ -17,13 +17,17 @@ from typing import TYPE_CHECKING
from abc import ABC from abc import ABC
if TYPE_CHECKING: if TYPE_CHECKING:
from mautrix import Client as MatrixClient from .client import MaubotMatrixClient
from .command_spec import CommandSpec
class Plugin(ABC): class Plugin(ABC):
def __init__(self, client: 'MatrixClient') -> None: def __init__(self, client: 'MaubotMatrixClient') -> None:
self.client = client self.client = client
def set_command_spec(self, spec: 'CommandSpec') -> None:
pass
async def start(self) -> None: async def start(self) -> None:
pass pass

View file

@ -4,3 +4,4 @@ SQLAlchemy
alembic alembic
commonmark commonmark
ruamel.yaml ruamel.yaml
attrs

View file

@ -27,6 +27,7 @@ setuptools.setup(
"alembic>=1.0.0,<2", "alembic>=1.0.0,<2",
"commonmark>=0.8.1,<1", "commonmark>=0.8.1,<1",
"ruamel.yaml>=0.15.35,<0.16", "ruamel.yaml>=0.15.35,<0.16",
"attrs>=18.2.0,<19",
], ],
classifiers=[ classifiers=[