Make non-regex matching faster

This commit is contained in:
Tulir Asokan 2019-06-23 03:08:03 +03:00
parent 2c54aa395a
commit 3992db4464

View file

@ -13,7 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import NewType, Optional, Pattern, Match, Union, Dict, List, Tuple, Set, Type, Any
from typing import (NewType, Optional, Callable, Pattern, Match, Union, Dict, List, Tuple, Set,
Type, Any)
from itertools import chain
import copy
import re
@ -34,10 +35,36 @@ class Config(BaseProxyConfig):
helper.copy("templates")
class ConfigError(Exception):
pass
class Key(str):
pass
class BlankMatch:
@staticmethod
def groups() -> List[str]:
return []
class SimplePattern:
_ptm = BlankMatch()
matcher: Callable[[str], bool]
def __init__(self, matcher: Callable[[str], bool]) -> None:
self.matcher = matcher
def match(self, val: str) -> BlankMatch:
if self.matcher(val):
return self._ptm
RMatch = Union[Match, BlankMatch]
RPattern = Union[Pattern, SimplePattern]
Index = NewType("Index", Union[str, int, Key])
variable_regex = re.compile(r"\$\${([0-9A-Za-z-_]+)}")
@ -102,7 +129,8 @@ class Template:
@dataclass
class Rule:
rooms: Set[RoomID]
matches: List[Pattern]
matches: List[RPattern]
not_matches: List[RPattern]
template: Template
type: Optional[EventType]
variables: Dict[str, JinjaTemplate]
@ -123,10 +151,6 @@ class Rule:
await evt.client.send_message_event(evt.room_id, self.type or self.template.type, content)
class ConfigError(Exception):
pass
class ReactBot(Plugin):
rules: Dict[str, Rule]
templates: Dict[str, Template]
@ -154,10 +178,30 @@ class ReactBot(Plugin):
except Exception as e:
raise ConfigError(f"Failed to load {name}") from e
@staticmethod
def _compile(pattern: str) -> RPattern:
esc = re.escape(pattern)
if esc == pattern:
return SimplePattern(lambda val: pattern in val)
elif pattern[0] == '^' and esc == f"\\^{pattern}":
pattern = pattern[1:]
return SimplePattern(lambda val: val.startswith(pattern))
elif pattern[-1] == '$' and esc == f"{pattern}\\$":
pattern = pattern[:-1]
return SimplePattern(lambda val: val.endswith(pattern))
elif pattern[0] == '^' and pattern[-1] == '$' and esc == f"\\^{pattern}\\$":
pattern = pattern[1:-1]
return SimplePattern(lambda val: val == pattern)
return re.compile(pattern)
def _compile_all(self, patterns: List[str]) -> List[RPattern]:
return [self._compile(pattern) for pattern in patterns]
def _make_rule(self, name: str, rule: Dict[str, Any]) -> Rule:
try:
return Rule(rooms=set(rule.get("rooms", [])),
matches=[re.compile(match) for match in rule.get("matches")],
matches=self._compile_all(rule["matches"]),
not_matches=self._compile_all(rule.get("not_matches", [])),
type=EventType.find(rule["type"]) if "type" in rule else None,
template=self.templates[rule["template"]],
variables=self._parse_variables(rule))