Add command matching stuff

This commit is contained in:
Tulir Asokan 2018-10-16 00:25:23 +03:00
parent c79ed97a47
commit 0b246e44a8
7 changed files with 196 additions and 34 deletions

View file

@ -1,3 +1,4 @@
from .plugin_base import Plugin
from .command_spec import CommandSpec, Command, PassiveCommand, Argument
from .event import FakeEvent as Event
from .client import MaubotMatrixClient as Client

View file

@ -13,33 +13,80 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Optional
from typing import Dict, List, Optional, Union, Callable
from aiohttp import ClientSession
import asyncio
import logging
from mautrix import Client as MatrixClient
from mautrix.client import EventHandler
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
EventType)
EventType, MessageEvent)
from .command_spec import ParsedCommand
from .db import DBClient
log = logging.getLogger("maubot.client")
class MaubotMatrixClient(MatrixClient):
def __init__(self, maubot_client: 'Client', *args, **kwargs):
super().__init__(*args, **kwargs)
self._maubot_client = maubot_client
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
async def _command_event_handler(self, evt: MessageEvent) -> None:
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass
class Client:
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
db_instance: DBClient
client: MaubotMatrixClient
def __init__(self, db_instance: DBClient) -> None:
self.db_instance: DBClient = db_instance
self.db_instance = db_instance
self.cache[self.id] = self
self.client: MatrixClient = MatrixClient(mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=log.getChild(self.id))
self.client = MaubotMatrixClient(maubot_client=self,
store=self.db_instance,
mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=log.getChild(self.id))
if self.autojoin:
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
@classmethod
def get(cls, user_id: UserID) -> Optional['Client']:
@ -103,9 +150,9 @@ class Client:
if value == self.db_instance.autojoin:
return
if value:
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
else:
self.client.remove_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
self.db_instance.autojoin = value
@property
@ -126,6 +173,6 @@ class Client:
# endregion
async def handle_invite(self, evt: StateEvent) -> None:
async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)

View file

@ -13,10 +13,11 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict
from typing import List, Dict, Pattern, Union, Tuple, Optional, Any
from attr import dataclass
import re
from mautrix.types import Event
from mautrix.types import MessageEvent, MatchedCommand, MatchedPassiveCommand
from mautrix.client.api.types.util import SerializableAttrs
@ -39,7 +40,103 @@ class PassiveCommand(SerializableAttrs['PassiveCommand']):
name: str
matches: str
match_against: str
match_event: Event = None
match_event: MessageEvent = None
class ParsedCommand:
name: str
is_passive: bool
arguments: List[str]
starts_with: str
matches: Pattern
match_against: str
match_event: MessageEvent
def __init__(self, command: Union[PassiveCommand, Command]) -> None:
if isinstance(command, PassiveCommand):
self._init_passive(command)
elif isinstance(command, Command):
self._init_active(command)
else:
raise ValueError("Command parameter must be a Command or a PassiveCommand.")
def _init_passive(self, command: PassiveCommand) -> None:
self.name = command.name
self.is_passive = True
self.match_against = command.match_against
self.matches = re.compile(command.matches)
self.match_event = command.match_event
def _init_active(self, command: Command) -> None:
self.name = command.syntax
self.is_passive = False
regex_builder = []
sw_builder = []
argument_encountered = False
for word in command.syntax.split(" "):
arg = command.arguments.get(word, None)
if arg is not None and len(word) > 0:
argument_encountered = True
regex = f"({arg.matches})" if arg.required else f"(?:{arg.matches})?"
self.arguments.append(word)
regex_builder.append(regex)
else:
if not argument_encountered:
sw_builder.append(word)
regex_builder.append(re.escape(word))
self.starts_with = "!" + " ".join(sw_builder)
self.matches = re.compile("^!" + " ".join(regex_builder) + "$")
self.match_against = "body"
def match(self, evt: MessageEvent) -> bool:
return self._match_passive(evt) if self.is_passive else self._match_active(evt)
@staticmethod
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
if '.' not in key:
return key, None
key, next_key = key.split('.', 1)
if len(key) > 0 and key[0] == "[":
end_index = next_key.index("]")
key = key[1:] + "." + next_key[:end_index]
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
return key, next_key
@classmethod
def _recursive_get(cls, data: Any, key: str) -> Any:
if not data:
return None
key, next_key = cls._parse_key(key)
if next_key is not None:
return cls._recursive_get(data[key], next_key)
return data[key]
def _match_passive(self, evt: MessageEvent) -> bool:
try:
match_against = self._recursive_get(evt.content, self.match_against)
except KeyError:
match_against = None
match_against = match_against or evt.content.body
matches = [[match.string[match.start():match.end()]] + list(match.groups())
for match in self.matches.finditer(match_against)]
if not matches:
return False
if evt.unsigned.passive_command is None:
evt.unsigned.passive_command = {}
evt.unsigned.passive_command[self.name] = MatchedPassiveCommand(captured=matches)
return True
def _match_active(self, evt: MessageEvent) -> bool:
if not evt.content.body.startswith(self.starts_with):
return False
match = self.matches.match(evt.content.body)
if not match:
return False
evt.content.command = MatchedCommand(matched=self.name,
arguments=dict(zip(self.arguments, match.groups())))
return True
@dataclass
@ -50,3 +147,6 @@ class CommandSpec(SerializableAttrs['CommandSpec']):
def __add__(self, other: 'CommandSpec') -> 'CommandSpec':
return CommandSpec(commands=self.commands + other.commands,
passive_commands=self.passive_commands + other.passive_commands)
def parse(self) -> List[ParsedCommand]:
return [ParsedCommand(command) for command in self.commands + self.passive_commands]

View file

@ -13,31 +13,39 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
from sqlalchemy.orm import Query
from sqlalchemy.ext.declarative import declarative_base
import json
from mautrix.types import JSON, UserID, FilterID, SyncToken, ContentURI
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
from mautrix.client.api.types.util import Serializable
from mautrix import ClientStore
from .command_spec import CommandSpec
Base: declarative_base = declarative_base()
class JSONEncodedDict(TypeDecorator):
impl = Text
def make_serializable_alchemy(serializable_type: Type[Serializable]):
class SerializableAlchemy(TypeDecorator):
impl = Text
@property
def python_type(self):
return dict
@property
def python_type(self):
return serializable_type
def process_literal_param(self, value, _):
return json.dumps(value) if value is not None else None
def process_literal_param(self, value: Serializable, _) -> str:
return json.dumps(value.serialize()) if value is not None else None
def process_bind_param(self, value, _):
return json.dumps(value) if value is not None else None
def process_bind_param(self, value: Serializable, _) -> str:
return json.dumps(value.serialize()) if value is not None else None
def process_result_value(self, value, _):
return json.loads(value) if value is not None else None
def process_result_value(self, value: str, _) -> serializable_type:
return serializable_type.deserialize(json.loads(value)) if value is not None else None
return SerializableAlchemy
class DBPlugin(Base):
@ -52,7 +60,7 @@ class DBPlugin(Base):
nullable=False)
class DBClient(Base):
class DBClient(ClientStore, Base):
query: Query
__tablename__ = "client"
@ -74,10 +82,10 @@ class DBCommandSpec(Base):
query: Query
__tablename__ = "command_spec"
owner: str = Column(String(255),
ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True)
plugin: str = Column(String(255),
ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True)
client: UserID = Column(String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
primary_key=True)
spec: JSON = Column(JSONEncodedDict, nullable=False)
spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)

View file

@ -17,13 +17,17 @@ from typing import TYPE_CHECKING
from abc import ABC
if TYPE_CHECKING:
from mautrix import Client as MatrixClient
from .client import MaubotMatrixClient
from .command_spec import CommandSpec
class Plugin(ABC):
def __init__(self, client: 'MatrixClient') -> None:
def __init__(self, client: 'MaubotMatrixClient') -> None:
self.client = client
def set_command_spec(self, spec: 'CommandSpec') -> None:
pass
async def start(self) -> None:
pass

View file

@ -4,3 +4,4 @@ SQLAlchemy
alembic
commonmark
ruamel.yaml
attrs

View file

@ -27,6 +27,7 @@ setuptools.setup(
"alembic>=1.0.0,<2",
"commonmark>=0.8.1,<1",
"ruamel.yaml>=0.15.35,<0.16",
"attrs>=18.2.0,<19",
],
classifiers=[