Fix CommandHandler descriptor
This commit is contained in:
parent
4ea980cb93
commit
8b5c637f76
2 changed files with 22 additions and 22 deletions
|
@ -14,7 +14,7 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
|
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
|
||||||
Dict, Tuple, Set)
|
Dict, Tuple, Set, Type)
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
@ -55,27 +55,28 @@ class CommandHandler:
|
||||||
self.__mb_event_handler__: bool = True
|
self.__mb_event_handler__: bool = True
|
||||||
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
|
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
|
||||||
self.__mb_msgtypes__: List[MessageType] = (MessageType.TEXT,)
|
self.__mb_msgtypes__: List[MessageType] = (MessageType.TEXT,)
|
||||||
self.__instance_vars: Dict[str, CommandHandler] = {}
|
self.__bound_copies__: Dict[Any, CommandHandler] = {}
|
||||||
self.__class_instance: Any = None
|
self.__bound_instance__: Any = None
|
||||||
|
|
||||||
def __copy__(self) -> 'CommandHandler':
|
def __get__(self, instance, instancetype):
|
||||||
|
if not instance or self.__bound_instance__:
|
||||||
|
return self
|
||||||
|
try:
|
||||||
|
return self.__bound_copies__[instance]
|
||||||
|
except KeyError:
|
||||||
new_ch = type(self)(self.__mb_func__)
|
new_ch = type(self)(self.__mb_func__)
|
||||||
keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
|
keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
|
||||||
"require_subcommand", "arg_fallthrough", "event_handler", "event_type", "msgtypes"]
|
"require_subcommand", "arg_fallthrough", "event_handler", "event_type",
|
||||||
|
"msgtypes"]
|
||||||
for key in keys:
|
for key in keys:
|
||||||
key = f"__mb_{key}__"
|
key = f"__mb_{key}__"
|
||||||
setattr(new_ch, key, getattr(self, key))
|
setattr(new_ch, key, getattr(self, key))
|
||||||
|
new_ch.__bound_instance__ = instance
|
||||||
|
new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype)
|
||||||
|
for subcmd in self.__mb_subcommands__]
|
||||||
|
self.__bound_copies__[instance] = new_ch
|
||||||
return new_ch
|
return new_ch
|
||||||
|
|
||||||
def __get__(self, instance, instancetype):
|
|
||||||
try:
|
|
||||||
return self.__instance_vars[instance]
|
|
||||||
except KeyError:
|
|
||||||
copy = self.__copy__()
|
|
||||||
copy.__class_instance = instance
|
|
||||||
self.__instance_vars[instance] = copy
|
|
||||||
return copy
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __command_match_unset(self, val: str) -> str:
|
def __command_match_unset(self, val: str) -> str:
|
||||||
raise NotImplementedError("Hmm")
|
raise NotImplementedError("Hmm")
|
||||||
|
@ -108,15 +109,15 @@ class CommandHandler:
|
||||||
await evt.reply(self.__mb_full_help__)
|
await evt.reply(self.__mb_full_help__)
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.__class_instance:
|
if self.__bound_instance__:
|
||||||
return await self.__mb_func__(self.__class_instance, evt, **call_args)
|
return await self.__mb_func__(self.__bound_instance__, evt, **call_args)
|
||||||
return await self.__mb_func__(evt, **call_args)
|
return await self.__mb_func__(evt, **call_args)
|
||||||
|
|
||||||
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
|
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
|
||||||
remaining_val: str) -> Tuple[bool, Any]:
|
remaining_val: str) -> Tuple[bool, Any]:
|
||||||
command, remaining_val = _split_in_two(remaining_val.strip(), " ")
|
command, remaining_val = _split_in_two(remaining_val.strip(), " ")
|
||||||
for subcommand in self.__mb_subcommands__:
|
for subcommand in self.__mb_subcommands__:
|
||||||
if subcommand.__mb_is_command_match__(subcommand.__class_instance, command):
|
if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command):
|
||||||
return True, await subcommand(evt, _existing_args=call_args,
|
return True, await subcommand(evt, _existing_args=call_args,
|
||||||
remaining_val=remaining_val)
|
remaining_val=remaining_val)
|
||||||
return False, None
|
return False, None
|
||||||
|
@ -156,7 +157,7 @@ class CommandHandler:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def __mb_name__(self) -> str:
|
def __mb_name__(self) -> str:
|
||||||
return self.__mb_get_name__(self.__class_instance)
|
return self.__mb_get_name__(self.__bound_instance__)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def __mb_prefix__(self) -> str:
|
def __mb_prefix__(self) -> str:
|
||||||
|
|
|
@ -23,7 +23,6 @@ from sqlalchemy.engine.base import Engine
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from mautrix.types import Event
|
|
||||||
from mautrix.util.config import BaseProxyConfig
|
from mautrix.util.config import BaseProxyConfig
|
||||||
from .client import MaubotMatrixClient
|
from .client import MaubotMatrixClient
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue