Make non-regex matching faster

This commit is contained in:
Tulir Asokan 2019-06-23 03:08:03 +03:00
parent 2c54aa395a
commit 3992db4464

View file

@ -13,7 +13,8 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import NewType, Optional, Pattern, Match, Union, Dict, List, Tuple, Set, Type, Any from typing import (NewType, Optional, Callable, Pattern, Match, Union, Dict, List, Tuple, Set,
Type, Any)
from itertools import chain from itertools import chain
import copy import copy
import re import re
@ -34,10 +35,36 @@ class Config(BaseProxyConfig):
helper.copy("templates") helper.copy("templates")
class ConfigError(Exception):
pass
class Key(str): class Key(str):
pass pass
class BlankMatch:
@staticmethod
def groups() -> List[str]:
return []
class SimplePattern:
_ptm = BlankMatch()
matcher: Callable[[str], bool]
def __init__(self, matcher: Callable[[str], bool]) -> None:
self.matcher = matcher
def match(self, val: str) -> BlankMatch:
if self.matcher(val):
return self._ptm
RMatch = Union[Match, BlankMatch]
RPattern = Union[Pattern, SimplePattern]
Index = NewType("Index", Union[str, int, Key]) Index = NewType("Index", Union[str, int, Key])
variable_regex = re.compile(r"\$\${([0-9A-Za-z-_]+)}") variable_regex = re.compile(r"\$\${([0-9A-Za-z-_]+)}")
@ -102,7 +129,8 @@ class Template:
@dataclass @dataclass
class Rule: class Rule:
rooms: Set[RoomID] rooms: Set[RoomID]
matches: List[Pattern] matches: List[RPattern]
not_matches: List[RPattern]
template: Template template: Template
type: Optional[EventType] type: Optional[EventType]
variables: Dict[str, JinjaTemplate] variables: Dict[str, JinjaTemplate]
@ -123,10 +151,6 @@ class Rule:
await evt.client.send_message_event(evt.room_id, self.type or self.template.type, content) await evt.client.send_message_event(evt.room_id, self.type or self.template.type, content)
class ConfigError(Exception):
pass
class ReactBot(Plugin): class ReactBot(Plugin):
rules: Dict[str, Rule] rules: Dict[str, Rule]
templates: Dict[str, Template] templates: Dict[str, Template]
@ -154,10 +178,30 @@ class ReactBot(Plugin):
except Exception as e: except Exception as e:
raise ConfigError(f"Failed to load {name}") from e raise ConfigError(f"Failed to load {name}") from e
@staticmethod
def _compile(pattern: str) -> RPattern:
esc = re.escape(pattern)
if esc == pattern:
return SimplePattern(lambda val: pattern in val)
elif pattern[0] == '^' and esc == f"\\^{pattern}":
pattern = pattern[1:]
return SimplePattern(lambda val: val.startswith(pattern))
elif pattern[-1] == '$' and esc == f"{pattern}\\$":
pattern = pattern[:-1]
return SimplePattern(lambda val: val.endswith(pattern))
elif pattern[0] == '^' and pattern[-1] == '$' and esc == f"\\^{pattern}\\$":
pattern = pattern[1:-1]
return SimplePattern(lambda val: val == pattern)
return re.compile(pattern)
def _compile_all(self, patterns: List[str]) -> List[RPattern]:
return [self._compile(pattern) for pattern in patterns]
def _make_rule(self, name: str, rule: Dict[str, Any]) -> Rule: def _make_rule(self, name: str, rule: Dict[str, Any]) -> Rule:
try: try:
return Rule(rooms=set(rule.get("rooms", [])), return Rule(rooms=set(rule.get("rooms", [])),
matches=[re.compile(match) for match in rule.get("matches")], matches=self._compile_all(rule["matches"]),
not_matches=self._compile_all(rule.get("not_matches", [])),
type=EventType.find(rule["type"]) if "type" in rule else None, type=EventType.find(rule["type"]) if "type" in rule else None,
template=self.templates[rule["template"]], template=self.templates[rule["template"]],
variables=self._parse_variables(rule)) variables=self._parse_variables(rule))