Fix CommandHandler descriptor

This commit is contained in:
Tulir Asokan 2019-01-18 22:58:43 +02:00
parent 4ea980cb93
commit 8b5c637f76
2 changed files with 22 additions and 22 deletions

View file

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
Dict, Tuple, Set)
Dict, Tuple, Set, Type)
from abc import ABC, abstractmethod
import asyncio
import functools
@ -55,26 +55,27 @@ class CommandHandler:
self.__mb_event_handler__: bool = True
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
self.__mb_msgtypes__: List[MessageType] = (MessageType.TEXT,)
self.__instance_vars: Dict[str, CommandHandler] = {}
self.__class_instance: Any = None
def __copy__(self) -> 'CommandHandler':
new_ch = type(self)(self.__mb_func__)
keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
"require_subcommand", "arg_fallthrough", "event_handler", "event_type", "msgtypes"]
for key in keys:
key = f"__mb_{key}__"
setattr(new_ch, key, getattr(self, key))
return new_ch
self.__bound_copies__: Dict[Any, CommandHandler] = {}
self.__bound_instance__: Any = None
def __get__(self, instance, instancetype):
if not instance or self.__bound_instance__:
return self
try:
return self.__instance_vars[instance]
return self.__bound_copies__[instance]
except KeyError:
copy = self.__copy__()
copy.__class_instance = instance
self.__instance_vars[instance] = copy
return copy
new_ch = type(self)(self.__mb_func__)
keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
"require_subcommand", "arg_fallthrough", "event_handler", "event_type",
"msgtypes"]
for key in keys:
key = f"__mb_{key}__"
setattr(new_ch, key, getattr(self, key))
new_ch.__bound_instance__ = instance
new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype)
for subcmd in self.__mb_subcommands__]
self.__bound_copies__[instance] = new_ch
return new_ch
@staticmethod
def __command_match_unset(self, val: str) -> str:
@ -108,15 +109,15 @@ class CommandHandler:
await evt.reply(self.__mb_full_help__)
return
if self.__class_instance:
return await self.__mb_func__(self.__class_instance, evt, **call_args)
if self.__bound_instance__:
return await self.__mb_func__(self.__bound_instance__, evt, **call_args)
return await self.__mb_func__(evt, **call_args)
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, Any]:
command, remaining_val = _split_in_two(remaining_val.strip(), " ")
for subcommand in self.__mb_subcommands__:
if subcommand.__mb_is_command_match__(subcommand.__class_instance, command):
if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command):
return True, await subcommand(evt, _existing_args=call_args,
remaining_val=remaining_val)
return False, None
@ -156,7 +157,7 @@ class CommandHandler:
@property
def __mb_name__(self) -> str:
return self.__mb_get_name__(self.__class_instance)
return self.__mb_get_name__(self.__bound_instance__)
@property
def __mb_prefix__(self) -> str:

View file

@ -23,7 +23,6 @@ from sqlalchemy.engine.base import Engine
from aiohttp import ClientSession
if TYPE_CHECKING:
from mautrix.types import Event
from mautrix.util.config import BaseProxyConfig
from .client import MaubotMatrixClient