From b213481d7dad5ba4c8ec3d1c05756ac54c016149 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 11 Dec 2020 19:59:19 +0200 Subject: [PATCH] Expose named capture groups and earlier variables in jinja variables (ref #5) --- reactbot/bot.py | 1 - reactbot/config.py | 8 +++++--- reactbot/rule.py | 12 +++++++----- reactbot/simplepattern.py | 6 +++++- reactbot/template.py | 11 +++++++---- 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/reactbot/bot.py b/reactbot/bot.py index 37bf465..b893110 100644 --- a/reactbot/bot.py +++ b/reactbot/bot.py @@ -14,7 +14,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Type, Tuple, Dict -from itertools import chain import time from attr import dataclass diff --git a/reactbot/config.py b/reactbot/config.py index 1f37570..8cb9ba5 100644 --- a/reactbot/config.py +++ b/reactbot/config.py @@ -88,9 +88,11 @@ class Config(BaseProxyConfig): return re.compile(pattern, flags=flags) @staticmethod - def _parse_variables(data: Dict[str, Any]) -> Dict[str, JinjaTemplate]: - return {name: JinjaTemplate(var_tpl) for name, var_tpl - in data.get("variables", {}).items()} + def _parse_variables(data: Dict[str, Any]) -> Dict[str, Any]: + return {name: (JinjaTemplate(var_tpl) + if isinstance(var_tpl, str) and var_tpl.startswith("{{") + else var_tpl) + for name, var_tpl in data.get("variables", {}).items()} @staticmethod def _parse_content(content: Union[Dict[str, Any], str]) -> Union[Dict[str, Any], JinjaTemplate]: diff --git a/reactbot/rule.py b/reactbot/rule.py index 441ffb7..f7703d2 100644 --- a/reactbot/rule.py +++ b/reactbot/rule.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Match, Dict, List, Set, Union, Pattern +from typing import Optional, Match, Dict, List, Set, Union, Pattern, Any from attr import dataclass from jinja2 import Template as JinjaTemplate @@ -36,7 +36,7 @@ class Rule: not_matches: List[RPattern] template: Template type: Optional[EventType] - variables: Dict[str, JinjaTemplate] + variables: Dict[str, Any] def _check_not_match(self, body: str) -> bool: for pattern in self.not_matches: @@ -58,7 +58,9 @@ class Rule: return None async def execute(self, evt: MessageEvent, match: Match) -> None: - content = self.template.execute(evt=evt, rule_vars=self.variables, - extra_vars={str(i): val for i, val in - enumerate(match.groups())}) + extra_vars = { + **{str(i): val for i, val in enumerate(match.groups())}, + **match.groupdict(), + } + content = self.template.execute(evt=evt, rule_vars=self.variables, extra_vars=extra_vars) await evt.client.send_message_event(evt.room_id, self.type or self.template.type, content) diff --git a/reactbot/simplepattern.py b/reactbot/simplepattern.py index 4b30890..d9e74e1 100644 --- a/reactbot/simplepattern.py +++ b/reactbot/simplepattern.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Callable, List, Optional +from typing import Callable, List, Dict, Optional import re @@ -22,6 +22,10 @@ class BlankMatch: def groups() -> List[str]: return [] + @staticmethod + def groupdict() -> Dict[str, str]: + return {} + class SimplePattern: _ptm = BlankMatch() diff --git a/reactbot/template.py b/reactbot/template.py index 09967b9..8b003ac 100644 --- a/reactbot/template.py +++ b/reactbot/template.py @@ -37,7 +37,7 @@ Index = Union[str, int, Key] @dataclass class Template: type: EventType - variables: Dict[str, JinjaTemplate] + variables: Dict[str, Any] content: Union[Dict[str, Any], JinjaTemplate] _variable_locations: List[Tuple[Index, ...]] = None @@ -78,9 +78,12 @@ class Template: def execute(self, evt: Event, rule_vars: Dict[str, JinjaTemplate], extra_vars: Dict[str, str] ) -> Dict[str, Any]: - variables = {**{name: template.render(event=evt) - for name, template in chain(self.variables.items(), rule_vars.items())}, - **extra_vars} + variables = extra_vars + for name, template in chain(rule_vars.items(), self.variables.items()): + if isinstance(template, JinjaTemplate): + variables[name] = template.render(event=evt, variables=variables) + else: + variables[name] = template if isinstance(self.content, JinjaTemplate): raw_json = self.content.render(event=evt, **variables) return json.loads(raw_json)