# Source: maubot_reactbot/reactbot/template.py (97 lines, 3.6 KiB, Python)

# reactbot - A maubot plugin that reacts to messages that match predefined rules.
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, Dict, List, Tuple, Any
from itertools import chain
import json
import copy
import re
from attr import dataclass
from jinja2 import Template as JinjaTemplate
from mautrix.types import EventType, Event
class Key(str):
    """A ``str`` subclass marking a path step that refers to a dict *key*.

    When it is the last element of a variable-location path, it means "the
    key itself contains placeholders", as opposed to the value stored under it.
    """


# Placeholder syntax is ``$${name}``; group 1 captures the variable name
# (ASCII letters, digits, ``-`` and ``_``).
variable_regex = re.compile(r"\$\${([0-9A-Za-z-_]+)}")

# One step into nested content: a dict key (``str``), a list index (``int``),
# or a ``Key`` marker that addresses the key itself.
Index = Union[str, int, Key]
@dataclass
class Template:
    """A reusable event-content template with ``$${name}`` placeholders.

    ``content`` is either a JSON-like structure, or a Jinja template whose
    rendered output is parsed as JSON. In the JSON-like form, ``$${name}``
    placeholders may appear in string values and in dict keys.

    ``init()`` must be called before ``execute()`` so that the placeholder
    locations are indexed (``_variable_locations`` defaults to ``None``).
    """

    type: EventType
    # Template-level variables; Jinja-template values are rendered per call.
    variables: Dict[str, Any]
    content: Union[Dict[str, Any], JinjaTemplate]
    # Paths to every placeholder inside ``content``; populated by ``init()``.
    _variable_locations: List[Tuple[Index, ...]] = None

    def init(self) -> 'Template':
        """Index every placeholder location in ``content``; return ``self`` for chaining."""
        self._variable_locations = []
        self._map_variable_locations((), self.content)
        return self

    def _map_variable_locations(self, path: Tuple[Index, ...], data: Any) -> None:
        """Recursively record the path of every placeholder found under ``data``.

        A recorded path ending in a ``Key`` marks a dict key to rename. Any
        other recorded path marks a string value to substitute. A key's own
        location is recorded before the locations inside its value, and those
        inner paths use the original (unrenamed) key.
        """
        if isinstance(data, list):
            for i, v in enumerate(data):
                self._map_variable_locations((*path, i), v)
        elif isinstance(data, dict):
            for k, v in data.items():
                # Non-string keys (e.g. ints) cannot hold placeholders, and
                # re.search would raise TypeError on them.
                if isinstance(k, str) and variable_regex.search(k):
                    self._variable_locations.append((*path, Key(k)))
                self._map_variable_locations((*path, k), v)
        elif isinstance(data, str):
            if variable_regex.search(data):
                self._variable_locations.append(path)

    @classmethod
    def _recurse(cls, content: Any, path: Tuple[Index, ...]) -> Any:
        """Return the object reached by indexing ``content`` with each step of ``path``."""
        for step in path:
            content = content[step]
        return content

    @staticmethod
    def _replace_variables(tpl: str, variables: Dict[str, Any]) -> Any:
        """Substitute ``$${name}`` placeholders in ``tpl`` with values from ``variables``.

        If ``tpl`` consists of exactly one placeholder, the raw value is
        returned as-is, so non-string types survive. Otherwise each
        placeholder is replaced with ``str(value)``.

        Raises:
            KeyError: if a placeholder names a variable missing from ``variables``.
        """
        full_var_match = variable_regex.fullmatch(tpl)
        if full_var_match:
            # Whole field is a single variable, just return the value to allow non-string types.
            return variables[full_var_match.group(1)]
        return variable_regex.sub(lambda match: str(variables[match.group(1)]), tpl)

    def execute(self, evt: Event, rule_vars: Dict[str, Any], extra_vars: Dict[str, str]
                ) -> Dict[str, Any]:
        """Render the template into event content.

        Args:
            evt: The triggering event; passed to Jinja templates as ``event``.
            rule_vars: Rule-level variables, resolved after ``extra_vars``.
            extra_vars: Base variables for this call. This dict is not modified.

        Variables are resolved in order: ``extra_vars``, then ``rule_vars``,
        then ``self.variables``. A later name overrides an earlier one.
        Jinja-template variables may reference variables already resolved
        through ``variables``.

        Returns:
            Freshly built content; ``self.content`` is never mutated.

        Raises:
            KeyError: if a placeholder references an undefined variable.
        """
        # Copy so the caller's dict is not polluted with rule/template variables.
        variables = dict(extra_vars)
        for name, template in chain(rule_vars.items(), self.variables.items()):
            if isinstance(template, JinjaTemplate):
                variables[name] = template.render(event=evt, variables=variables)
            else:
                variables[name] = template
        if isinstance(self.content, JinjaTemplate):
            raw_json = self.content.render(event=evt, **variables)
            return json.loads(raw_json)
        content = copy.deepcopy(self.content)
        # Walk the locations backwards. A templated key's location is recorded
        # before the locations inside its value, and those inner paths still use
        # the original key. Renaming the key first would make them unreachable
        # (KeyError). Going in reverse substitutes the inner values first, then
        # renames the key.
        for path in reversed(self._variable_locations):
            data: Any = self._recurse(content, path[:-1])
            key = path[-1]
            if isinstance(key, Key):
                key = str(key)
                # The renamed key is re-inserted at the end of its dict.
                data[self._replace_variables(key, variables)] = data.pop(key)
            else:
                data[key] = self._replace_variables(data[key], variables)
        return content