Add support for capturing the match in simple matches
This commit is contained in:
parent
16e4b8e6d8
commit
3964aa6f12
1 changed files with 46 additions and 23 deletions
|
@ -1,5 +1,5 @@
|
|||
# reminder - A maubot plugin that reacts to messages that match predefined rules.
|
||||
# Copyright (C) 2019 Tulir Asokan
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -13,39 +13,59 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Callable, List, Dict, Optional
|
||||
from typing import Callable, List, Dict, Optional, NamedTuple
|
||||
import re
|
||||
|
||||
|
||||
class BlankMatch:
|
||||
@staticmethod
|
||||
def groups() -> List[str]:
|
||||
return []
|
||||
class SimpleMatch(NamedTuple):
|
||||
value: str
|
||||
|
||||
@staticmethod
|
||||
def group(group: int) -> str:
|
||||
return ""
|
||||
def groups(self) -> List[str]:
|
||||
return [self.value]
|
||||
|
||||
@staticmethod
|
||||
def groupdict() -> Dict[str, str]:
|
||||
def group(self, group: int) -> Optional[str]:
|
||||
if group == 0:
|
||||
return self.value
|
||||
return None
|
||||
|
||||
def groupdict(self) -> Dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
class SimplePattern:
|
||||
_ptm = BlankMatch()
|
||||
def matcher_equals(val: str, pattern: str) -> bool:
|
||||
return val == pattern
|
||||
|
||||
matcher: Callable[[str], bool]
|
||||
|
||||
def matcher_startswith(val: str, pattern: str) -> bool:
|
||||
return val.startswith(pattern)
|
||||
|
||||
|
||||
def matcher_endswith(val: str, pattern: str) -> bool:
|
||||
return val.endswith(pattern)
|
||||
|
||||
|
||||
def matcher_contains(val: str, pattern: str) -> bool:
|
||||
return pattern in val
|
||||
|
||||
|
||||
SimpleMatcherFunc = Callable[[str, str], bool]
|
||||
|
||||
|
||||
class SimplePattern:
|
||||
matcher: SimpleMatcherFunc
|
||||
pattern: str
|
||||
ignorecase: bool
|
||||
|
||||
def __init__(self, matcher: Callable[[str], bool], ignorecase: bool) -> None:
|
||||
def __init__(self, matcher: SimpleMatcherFunc, pattern: str, ignorecase: bool) -> None:
|
||||
self.matcher = matcher
|
||||
self.pattern = pattern
|
||||
self.ignorecase = ignorecase
|
||||
|
||||
def search(self, val: str) -> BlankMatch:
|
||||
def search(self, val: str) -> SimpleMatch:
|
||||
if self.ignorecase:
|
||||
val = val.lower()
|
||||
if self.matcher(val):
|
||||
return self._ptm
|
||||
if self.matcher(val, self.pattern):
|
||||
return SimpleMatch(self.pattern)
|
||||
|
||||
@staticmethod
|
||||
def compile(pattern: str, flags: re.RegexFlag = re.RegexFlag(0), force_raw: bool = False
|
||||
|
@ -58,13 +78,16 @@ class SimplePattern:
|
|||
first, last = pattern[0], pattern[-1]
|
||||
if first == '^' and last == '$' and (force_raw or esc == f"\\^{pattern[1:-1]}\\$"):
|
||||
s_pattern = s_pattern[1:-1]
|
||||
return SimplePattern(lambda val: val == s_pattern, ignorecase=ignorecase)
|
||||
func = matcher_equals
|
||||
elif first == '^' and (force_raw or esc == f"\\^{pattern[1:]}"):
|
||||
s_pattern = s_pattern[1:]
|
||||
return SimplePattern(lambda val: val.startswith(s_pattern), ignorecase=ignorecase)
|
||||
func = matcher_startswith
|
||||
elif last == '$' and (force_raw or esc == f"{pattern[:-1]}\\$"):
|
||||
s_pattern = s_pattern[:-1]
|
||||
return SimplePattern(lambda val: val.endswith(s_pattern), ignorecase=ignorecase)
|
||||
func = matcher_endswith
|
||||
elif force_raw or esc == pattern:
|
||||
return SimplePattern(lambda val: s_pattern in val, ignorecase=ignorecase)
|
||||
return None
|
||||
func = matcher_contains
|
||||
else:
|
||||
# Not a simple pattern
|
||||
return None
|
||||
return SimplePattern(matcher=func, pattern=s_pattern, ignorecase=ignorecase)
|
||||
|
|
Loading…
Reference in a new issue