Expose named capture groups and earlier variables in jinja variables (ref #5)

This commit is contained in:
Tulir Asokan 2020-12-11 19:59:19 +02:00
parent 0790b429b3
commit b213481d7d
5 changed files with 24 additions and 14 deletions

View file

@ -14,7 +14,6 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Tuple, Dict from typing import Type, Tuple, Dict
from itertools import chain
import time import time
from attr import dataclass from attr import dataclass

View file

@ -88,9 +88,11 @@ class Config(BaseProxyConfig):
return re.compile(pattern, flags=flags) return re.compile(pattern, flags=flags)
@staticmethod @staticmethod
def _parse_variables(data: Dict[str, Any]) -> Dict[str, JinjaTemplate]: def _parse_variables(data: Dict[str, Any]) -> Dict[str, Any]:
return {name: JinjaTemplate(var_tpl) for name, var_tpl return {name: (JinjaTemplate(var_tpl)
in data.get("variables", {}).items()} if isinstance(var_tpl, str) and var_tpl.startswith("{{")
else var_tpl)
for name, var_tpl in data.get("variables", {}).items()}
@staticmethod @staticmethod
def _parse_content(content: Union[Dict[str, Any], str]) -> Union[Dict[str, Any], JinjaTemplate]: def _parse_content(content: Union[Dict[str, Any], str]) -> Union[Dict[str, Any], JinjaTemplate]:

View file

@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Match, Dict, List, Set, Union, Pattern from typing import Optional, Match, Dict, List, Set, Union, Pattern, Any
from attr import dataclass from attr import dataclass
from jinja2 import Template as JinjaTemplate from jinja2 import Template as JinjaTemplate
@ -36,7 +36,7 @@ class Rule:
not_matches: List[RPattern] not_matches: List[RPattern]
template: Template template: Template
type: Optional[EventType] type: Optional[EventType]
variables: Dict[str, JinjaTemplate] variables: Dict[str, Any]
def _check_not_match(self, body: str) -> bool: def _check_not_match(self, body: str) -> bool:
for pattern in self.not_matches: for pattern in self.not_matches:
@ -58,7 +58,9 @@ class Rule:
return None return None
async def execute(self, evt: MessageEvent, match: Match) -> None: async def execute(self, evt: MessageEvent, match: Match) -> None:
content = self.template.execute(evt=evt, rule_vars=self.variables, extra_vars = {
extra_vars={str(i): val for i, val in **{str(i): val for i, val in enumerate(match.groups())},
enumerate(match.groups())}) **match.groupdict(),
}
content = self.template.execute(evt=evt, rule_vars=self.variables, extra_vars=extra_vars)
await evt.client.send_message_event(evt.room_id, self.type or self.template.type, content) await evt.client.send_message_event(evt.room_id, self.type or self.template.type, content)

View file

@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, List, Optional from typing import Callable, List, Dict, Optional
import re import re
@ -22,6 +22,10 @@ class BlankMatch:
def groups() -> List[str]: def groups() -> List[str]:
return [] return []
@staticmethod
def groupdict() -> Dict[str, str]:
return {}
class SimplePattern: class SimplePattern:
_ptm = BlankMatch() _ptm = BlankMatch()

View file

@ -37,7 +37,7 @@ Index = Union[str, int, Key]
@dataclass @dataclass
class Template: class Template:
type: EventType type: EventType
variables: Dict[str, JinjaTemplate] variables: Dict[str, Any]
content: Union[Dict[str, Any], JinjaTemplate] content: Union[Dict[str, Any], JinjaTemplate]
_variable_locations: List[Tuple[Index, ...]] = None _variable_locations: List[Tuple[Index, ...]] = None
@ -78,9 +78,12 @@ class Template:
def execute(self, evt: Event, rule_vars: Dict[str, JinjaTemplate], extra_vars: Dict[str, str] def execute(self, evt: Event, rule_vars: Dict[str, JinjaTemplate], extra_vars: Dict[str, str]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
variables = {**{name: template.render(event=evt) variables = extra_vars
for name, template in chain(self.variables.items(), rule_vars.items())}, for name, template in chain(rule_vars.items(), self.variables.items()):
**extra_vars} if isinstance(template, JinjaTemplate):
variables[name] = template.render(event=evt, variables=variables)
else:
variables[name] = template
if isinstance(self.content, JinjaTemplate): if isinstance(self.content, JinjaTemplate):
raw_json = self.content.render(event=evt, **variables) raw_json = self.content.render(event=evt, **variables)
return json.loads(raw_json) return json.loads(raw_json)