Blacken and isort code, add pre-commit and CI linting

This commit is contained in:
Tulir Asokan 2023-10-05 22:22:10 +03:00
parent 3507b3b63a
commit 3ca366fea9
8 changed files with 125 additions and 52 deletions

24
.github/workflows/python-lint.yml vendored Normal file
View file

@ -0,0 +1,24 @@
name: Python lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.11"
- uses: isort/isort-action@master
with:
sortPaths: "./reactbot"
- uses: psf/black@stable
with:
src: "./reactbot"
- name: pre-commit
run: |
pip install pre-commit
pre-commit run -av trailing-whitespace
pre-commit run -av end-of-file-fixer
pre-commit run -av check-added-large-files

19
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,19 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
- id: end-of-file-fixer
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
- id: black
language_version: python3
files: ^rss/.*\.pyi?$
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
files: ^rss/.*\.pyi?$

11
pyproject.toml Normal file
View file

@ -0,0 +1,11 @@
[tool.isort]
profile = "black"
force_to_top = "typing"
from_first = true
combine_as_imports = true
known_first_party = ["mautrix", "maubot"]
line_length = 99
[tool.black]
line-length = 99
target-version = ["py38"]

View file

@ -13,16 +13,15 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Tuple, Dict from typing import Dict, Tuple, Type
import time import time
from attr import dataclass from attr import dataclass
from mautrix.types import EventType, MessageType, UserID, RoomID from maubot import MessageEvent, Plugin
from mautrix.util.config import BaseProxyConfig
from maubot import Plugin, MessageEvent
from maubot.handlers import event from maubot.handlers import event
from mautrix.types import EventType, MessageType, RoomID, UserID
from mautrix.util.config import BaseProxyConfig
from .config import Config, ConfigError from .config import Config, ConfigError
@ -73,12 +72,15 @@ class ReactBot(Plugin):
fi.max = self.config["antispam.room.max"] fi.max = self.config["antispam.room.max"]
fi.delay = self.config["antispam.room.delay"] fi.delay = self.config["antispam.room.delay"]
def _make_flood_info(self, for_type: str) -> 'FloodInfo': def _make_flood_info(self, for_type: str) -> "FloodInfo":
return FloodInfo(max=self.config[f"antispam.{for_type}.max"], return FloodInfo(
max=self.config[f"antispam.{for_type}.max"],
delay=self.config[f"antispam.{for_type}.delay"], delay=self.config[f"antispam.{for_type}.delay"],
count=0, last_message=0) count=0,
last_message=0,
)
def _get_flood_info(self, flood_map: dict, key: str, for_type: str) -> 'FloodInfo': def _get_flood_info(self, flood_map: dict, key: str, for_type: str) -> "FloodInfo":
try: try:
return flood_map[key] return flood_map[key]
except KeyError: except KeyError:
@ -86,8 +88,10 @@ class ReactBot(Plugin):
return fi return fi
def is_flood(self, evt: MessageEvent) -> bool: def is_flood(self, evt: MessageEvent) -> bool:
return (self._get_flood_info(self.user_flood, evt.sender, "user").bump() return (
or self._get_flood_info(self.room_flood, evt.room_id, "room").bump()) self._get_flood_info(self.user_flood, evt.sender, "user").bump()
or self._get_flood_info(self.room_flood, evt.room_id, "room").bump()
)
@event.on(EventType.ROOM_MESSAGE) @event.on(EventType.ROOM_MESSAGE)
async def event_handler(self, evt: MessageEvent) -> None: async def event_handler(self, evt: MessageEvent) -> None:

View file

@ -13,18 +13,18 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Union, Dict, Any from typing import Any, Dict, List, Union
import re import re
from jinja2 import Template as JinjaStringTemplate from jinja2 import Template as JinjaStringTemplate
from jinja2.nativetypes import NativeTemplate as JinjaNativeTemplate from jinja2.nativetypes import NativeTemplate as JinjaNativeTemplate
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from mautrix.types import EventType from mautrix.types import EventType
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from .rule import RPattern, Rule
from .simplepattern import SimplePattern from .simplepattern import SimplePattern
from .template import Template from .template import Template
from .rule import Rule, RPattern
InputPattern = Union[str, Dict[str, str]] InputPattern = Union[str, Dict[str, str]]
@ -49,28 +49,32 @@ class Config(BaseProxyConfig):
self.rules = {} self.rules = {}
self.default_flags = self._get_flags(self["default_flags"]) self.default_flags = self._get_flags(self["default_flags"])
self.templates = {name: self._make_template(name, tpl) self.templates = {
for name, tpl in self["templates"].items()} name: self._make_template(name, tpl) for name, tpl in self["templates"].items()
self.rules = {name: self._make_rule(name, rule) }
for name, rule in self["rules"].items()} self.rules = {name: self._make_rule(name, rule) for name, rule in self["rules"].items()}
def _make_rule(self, name: str, rule: Dict[str, Any]) -> Rule: def _make_rule(self, name: str, rule: Dict[str, Any]) -> Rule:
try: try:
return Rule(rooms=set(rule.get("rooms", [])), return Rule(
rooms=set(rule.get("rooms", [])),
not_rooms=set(rule.get("not_rooms", [])), not_rooms=set(rule.get("not_rooms", [])),
matches=self._compile_all(rule["matches"]), matches=self._compile_all(rule["matches"]),
not_matches=self._compile_all(rule.get("not_matches", [])), not_matches=self._compile_all(rule.get("not_matches", [])),
type=EventType.find(rule["type"]) if "type" in rule else None, type=EventType.find(rule["type"]) if "type" in rule else None,
template=self.templates[rule["template"]], template=self.templates[rule["template"]],
variables=self._parse_variables(rule)) variables=self._parse_variables(rule),
)
except Exception as e: except Exception as e:
raise ConfigError(f"Failed to load {name}") from e raise ConfigError(f"Failed to load {name}") from e
def _make_template(self, name: str, tpl: Dict[str, Any]) -> Template: def _make_template(self, name: str, tpl: Dict[str, Any]) -> Template:
try: try:
return Template(type=EventType.find(tpl.get("type", "m.room.message")), return Template(
type=EventType.find(tpl.get("type", "m.room.message")),
variables=self._parse_variables(tpl), variables=self._parse_variables(tpl),
content=self._parse_content(tpl.get("content", None))).init() content=self._parse_content(tpl.get("content", None)),
).init()
except Exception as e: except Exception as e:
raise ConfigError(f"Failed to load {name}") from e raise ConfigError(f"Failed to load {name}") from e
@ -93,13 +97,19 @@ class Config(BaseProxyConfig):
@staticmethod @staticmethod
def _parse_variables(data: Dict[str, Any]) -> Dict[str, Any]: def _parse_variables(data: Dict[str, Any]) -> Dict[str, Any]:
return {name: (JinjaNativeTemplate(var_tpl) return {
name: (
JinjaNativeTemplate(var_tpl)
if isinstance(var_tpl, str) and var_tpl.startswith("{{") if isinstance(var_tpl, str) and var_tpl.startswith("{{")
else var_tpl) else var_tpl
for name, var_tpl in data.get("variables", {}).items()} )
for name, var_tpl in data.get("variables", {}).items()
}
@staticmethod @staticmethod
def _parse_content(content: Union[Dict[str, Any], str]) -> Union[Dict[str, Any], JinjaStringTemplate]: def _parse_content(
content: Union[Dict[str, Any], str]
) -> Union[Dict[str, Any], JinjaStringTemplate]:
if not content: if not content:
return {} return {}
elif isinstance(content, str): elif isinstance(content, str):

View file

@ -13,16 +13,15 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Match, Dict, List, Set, Union, Pattern, Any from typing import Any, Dict, List, Match, Optional, Pattern, Set, Union
from attr import dataclass from attr import dataclass
from mautrix.types import RoomID, EventType
from maubot import MessageEvent from maubot import MessageEvent
from mautrix.types import EventType, RoomID
from .template import Template, OmitValue
from .simplepattern import SimplePattern from .simplepattern import SimplePattern
from .template import OmitValue, Template
RPattern = Union[Pattern, SimplePattern] RPattern = Union[Pattern, SimplePattern]

View file

@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, List, Dict, Optional, NamedTuple from typing import Callable, Dict, List, NamedTuple, Optional
import re import re
@ -68,21 +68,22 @@ class SimplePattern:
return SimpleMatch(self.pattern) return SimpleMatch(self.pattern)
@staticmethod @staticmethod
def compile(pattern: str, flags: re.RegexFlag = re.RegexFlag(0), force_raw: bool = False def compile(
) -> Optional['SimplePattern']: pattern: str, flags: re.RegexFlag = re.RegexFlag(0), force_raw: bool = False
) -> Optional["SimplePattern"]:
ignorecase = flags == re.IGNORECASE ignorecase = flags == re.IGNORECASE
s_pattern = pattern.lower() if ignorecase else pattern s_pattern = pattern.lower() if ignorecase else pattern
esc = "" esc = ""
if not force_raw: if not force_raw:
esc = re.escape(pattern) esc = re.escape(pattern)
first, last = pattern[0], pattern[-1] first, last = pattern[0], pattern[-1]
if first == '^' and last == '$' and (force_raw or esc == f"\\^{pattern[1:-1]}\\$"): if first == "^" and last == "$" and (force_raw or esc == f"\\^{pattern[1:-1]}\\$"):
s_pattern = s_pattern[1:-1] s_pattern = s_pattern[1:-1]
func = matcher_equals func = matcher_equals
elif first == '^' and (force_raw or esc == f"\\^{pattern[1:]}"): elif first == "^" and (force_raw or esc == f"\\^{pattern[1:]}"):
s_pattern = s_pattern[1:] s_pattern = s_pattern[1:]
func = matcher_startswith func = matcher_startswith
elif last == '$' and (force_raw or esc == f"{pattern[:-1]}\\$"): elif last == "$" and (force_raw or esc == f"{pattern[:-1]}\\$"):
s_pattern = s_pattern[:-1] s_pattern = s_pattern[:-1]
func = matcher_endswith func = matcher_endswith
elif force_raw or esc == pattern: elif force_raw or esc == pattern:

View file

@ -13,17 +13,17 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, Dict, List, Tuple, Any from typing import Any, Dict, List, Tuple, Union
from itertools import chain from itertools import chain
import json
import copy import copy
import json
import re import re
from attr import dataclass from attr import dataclass
from jinja2 import Template as JinjaStringTemplate from jinja2 import Template as JinjaStringTemplate
from jinja2.nativetypes import Template as JinjaNativeTemplate from jinja2.nativetypes import Template as JinjaNativeTemplate
from mautrix.types import EventType, Event from mautrix.types import Event, EventType
class Key(str): class Key(str):
@ -48,7 +48,7 @@ class Template:
_variable_locations: List[Tuple[Index, ...]] = None _variable_locations: List[Tuple[Index, ...]] = None
def init(self) -> 'Template': def init(self) -> "Template":
self._variable_locations = [] self._variable_locations = []
self._map_variable_locations((), self.content) self._map_variable_locations((), self.content)
return self return self
@ -80,13 +80,18 @@ class Template:
return variables[full_var_match.group(1)] return variables[full_var_match.group(1)]
return variable_regex.sub(lambda match: str(variables[match.group(1)]), tpl) return variable_regex.sub(lambda match: str(variables[match.group(1)]), tpl)
def execute(self, evt: Event, rule_vars: Dict[str, Any], extra_vars: Dict[str, str] def execute(
self, evt: Event, rule_vars: Dict[str, Any], extra_vars: Dict[str, str]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
variables = extra_vars variables = extra_vars
for name, template in chain(rule_vars.items(), self.variables.items()): for name, template in chain(rule_vars.items(), self.variables.items()):
if isinstance(template, JinjaNativeTemplate): if isinstance(template, JinjaNativeTemplate):
rendered_var = template.render(event=evt, variables=variables, **global_vars) rendered_var = template.render(event=evt, variables=variables, **global_vars)
if not isinstance(rendered_var, (str, int, list, tuple, dict, bool)) and rendered_var is not None and rendered_var is not OmitValue: if (
not isinstance(rendered_var, (str, int, list, tuple, dict, bool))
and rendered_var is not None
and rendered_var is not OmitValue
):
rendered_var = str(rendered_var) rendered_var = str(rendered_var)
variables[name] = rendered_var variables[name] = rendered_var
else: else: