Expose named capture groups and earlier variables in jinja variables (ref #5)

This commit is contained in:
Tulir Asokan 2020-12-11 19:59:19 +02:00
parent 0790b429b3
commit b213481d7d
5 changed files with 24 additions and 14 deletions

View file

@ -14,7 +14,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Tuple, Dict
from itertools import chain
import time
from attr import dataclass

View file

@ -88,9 +88,11 @@ class Config(BaseProxyConfig):
return re.compile(pattern, flags=flags)
@staticmethod
def _parse_variables(data: Dict[str, Any]) -> Dict[str, JinjaTemplate]:
return {name: JinjaTemplate(var_tpl) for name, var_tpl
in data.get("variables", {}).items()}
def _parse_variables(data: Dict[str, Any]) -> Dict[str, Any]:
return {name: (JinjaTemplate(var_tpl)
if isinstance(var_tpl, str) and var_tpl.startswith("{{")
else var_tpl)
for name, var_tpl in data.get("variables", {}).items()}
@staticmethod
def _parse_content(content: Union[Dict[str, Any], str]) -> Union[Dict[str, Any], JinjaTemplate]:

View file

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Match, Dict, List, Set, Union, Pattern
from typing import Optional, Match, Dict, List, Set, Union, Pattern, Any
from attr import dataclass
from jinja2 import Template as JinjaTemplate
@ -36,7 +36,7 @@ class Rule:
not_matches: List[RPattern]
template: Template
type: Optional[EventType]
variables: Dict[str, JinjaTemplate]
variables: Dict[str, Any]
def _check_not_match(self, body: str) -> bool:
for pattern in self.not_matches:
@ -58,7 +58,9 @@ class Rule:
return None
async def execute(self, evt: MessageEvent, match: Match) -> None:
content = self.template.execute(evt=evt, rule_vars=self.variables,
extra_vars={str(i): val for i, val in
enumerate(match.groups())})
extra_vars = {
**{str(i): val for i, val in enumerate(match.groups())},
**match.groupdict(),
}
content = self.template.execute(evt=evt, rule_vars=self.variables, extra_vars=extra_vars)
await evt.client.send_message_event(evt.room_id, self.type or self.template.type, content)

View file

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, List, Optional
from typing import Callable, List, Dict, Optional
import re
@ -22,6 +22,10 @@ class BlankMatch:
def groups() -> List[str]:
return []
@staticmethod
def groupdict() -> Dict[str, str]:
return {}
class SimplePattern:
_ptm = BlankMatch()

View file

@ -37,7 +37,7 @@ Index = Union[str, int, Key]
@dataclass
class Template:
type: EventType
variables: Dict[str, JinjaTemplate]
variables: Dict[str, Any]
content: Union[Dict[str, Any], JinjaTemplate]
_variable_locations: List[Tuple[Index, ...]] = None
@ -78,9 +78,12 @@ class Template:
def execute(self, evt: Event, rule_vars: Dict[str, JinjaTemplate], extra_vars: Dict[str, str]
) -> Dict[str, Any]:
variables = {**{name: template.render(event=evt)
for name, template in chain(self.variables.items(), rule_vars.items())},
**extra_vars}
variables = extra_vars
for name, template in chain(rule_vars.items(), self.variables.items()):
if isinstance(template, JinjaTemplate):
variables[name] = template.render(event=evt, variables=variables)
else:
variables[name] = template
if isinstance(self.content, JinjaTemplate):
raw_json = self.content.render(event=evt, **variables)
return json.loads(raw_json)