Fix CommandHandler descriptor
This commit is contained in:
parent
4ea980cb93
commit
8b5c637f76
2 changed files with 22 additions and 22 deletions
|
@ -14,7 +14,7 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
|
||||
Dict, Tuple, Set)
|
||||
Dict, Tuple, Set, Type)
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
import functools
|
||||
|
@ -55,26 +55,27 @@ class CommandHandler:
|
|||
self.__mb_event_handler__: bool = True
|
||||
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
|
||||
self.__mb_msgtypes__: List[MessageType] = (MessageType.TEXT,)
|
||||
self.__instance_vars: Dict[str, CommandHandler] = {}
|
||||
self.__class_instance: Any = None
|
||||
|
||||
def __copy__(self) -> 'CommandHandler':
|
||||
new_ch = type(self)(self.__mb_func__)
|
||||
keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
|
||||
"require_subcommand", "arg_fallthrough", "event_handler", "event_type", "msgtypes"]
|
||||
for key in keys:
|
||||
key = f"__mb_{key}__"
|
||||
setattr(new_ch, key, getattr(self, key))
|
||||
return new_ch
|
||||
self.__bound_copies__: Dict[Any, CommandHandler] = {}
|
||||
self.__bound_instance__: Any = None
|
||||
|
||||
def __get__(self, instance, instancetype):
|
||||
if not instance or self.__bound_instance__:
|
||||
return self
|
||||
try:
|
||||
return self.__instance_vars[instance]
|
||||
return self.__bound_copies__[instance]
|
||||
except KeyError:
|
||||
copy = self.__copy__()
|
||||
copy.__class_instance = instance
|
||||
self.__instance_vars[instance] = copy
|
||||
return copy
|
||||
new_ch = type(self)(self.__mb_func__)
|
||||
keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
|
||||
"require_subcommand", "arg_fallthrough", "event_handler", "event_type",
|
||||
"msgtypes"]
|
||||
for key in keys:
|
||||
key = f"__mb_{key}__"
|
||||
setattr(new_ch, key, getattr(self, key))
|
||||
new_ch.__bound_instance__ = instance
|
||||
new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype)
|
||||
for subcmd in self.__mb_subcommands__]
|
||||
self.__bound_copies__[instance] = new_ch
|
||||
return new_ch
|
||||
|
||||
@staticmethod
|
||||
def __command_match_unset(self, val: str) -> str:
|
||||
|
@ -108,15 +109,15 @@ class CommandHandler:
|
|||
await evt.reply(self.__mb_full_help__)
|
||||
return
|
||||
|
||||
if self.__class_instance:
|
||||
return await self.__mb_func__(self.__class_instance, evt, **call_args)
|
||||
if self.__bound_instance__:
|
||||
return await self.__mb_func__(self.__bound_instance__, evt, **call_args)
|
||||
return await self.__mb_func__(evt, **call_args)
|
||||
|
||||
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
|
||||
remaining_val: str) -> Tuple[bool, Any]:
|
||||
command, remaining_val = _split_in_two(remaining_val.strip(), " ")
|
||||
for subcommand in self.__mb_subcommands__:
|
||||
if subcommand.__mb_is_command_match__(subcommand.__class_instance, command):
|
||||
if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command):
|
||||
return True, await subcommand(evt, _existing_args=call_args,
|
||||
remaining_val=remaining_val)
|
||||
return False, None
|
||||
|
@ -156,7 +157,7 @@ class CommandHandler:
|
|||
|
||||
@property
|
||||
def __mb_name__(self) -> str:
|
||||
return self.__mb_get_name__(self.__class_instance)
|
||||
return self.__mb_get_name__(self.__bound_instance__)
|
||||
|
||||
@property
|
||||
def __mb_prefix__(self) -> str:
|
||||
|
|
|
@ -23,7 +23,6 @@ from sqlalchemy.engine.base import Engine
|
|||
from aiohttp import ClientSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mautrix.types import Event
|
||||
from mautrix.util.config import BaseProxyConfig
|
||||
from .client import MaubotMatrixClient
|
||||
|
||||
|
|
Loading…
Reference in a new issue