Blacken and isort code, add pre-commit and CI linting

This commit is contained in:
Tulir Asokan 2023-10-05 22:22:10 +03:00
parent 3507b3b63a
commit 3ca366fea9
8 changed files with 125 additions and 52 deletions

24
.github/workflows/python-lint.yml vendored Normal file
View file

@ -0,0 +1,24 @@
name: Python lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.11"
- uses: isort/isort-action@master
with:
sortPaths: "./reactbot"
- uses: psf/black@stable
with:
src: "./reactbot"
- name: pre-commit
run: |
pip install pre-commit
pre-commit run -av trailing-whitespace
pre-commit run -av end-of-file-fixer
pre-commit run -av check-added-large-files

19
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,19 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
- id: end-of-file-fixer
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
- id: black
language_version: python3
files: ^rss/.*\.pyi?$
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
files: ^rss/.*\.pyi?$

11
pyproject.toml Normal file
View file

@ -0,0 +1,11 @@
[tool.isort]
profile = "black"
force_to_top = "typing"
from_first = true
combine_as_imports = true
known_first_party = ["mautrix", "maubot"]
line_length = 99
[tool.black]
line-length = 99
target-version = ["py38"]

View file

@ -13,16 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Tuple, Dict
from typing import Dict, Tuple, Type
import time
from attr import dataclass
from mautrix.types import EventType, MessageType, UserID, RoomID
from mautrix.util.config import BaseProxyConfig
from maubot import Plugin, MessageEvent
from maubot import MessageEvent, Plugin
from maubot.handlers import event
from mautrix.types import EventType, MessageType, RoomID, UserID
from mautrix.util.config import BaseProxyConfig
from .config import Config, ConfigError
@ -73,12 +72,15 @@ class ReactBot(Plugin):
fi.max = self.config["antispam.room.max"]
fi.delay = self.config["antispam.room.delay"]
def _make_flood_info(self, for_type: str) -> 'FloodInfo':
return FloodInfo(max=self.config[f"antispam.{for_type}.max"],
def _make_flood_info(self, for_type: str) -> "FloodInfo":
return FloodInfo(
max=self.config[f"antispam.{for_type}.max"],
delay=self.config[f"antispam.{for_type}.delay"],
count=0, last_message=0)
count=0,
last_message=0,
)
def _get_flood_info(self, flood_map: dict, key: str, for_type: str) -> 'FloodInfo':
def _get_flood_info(self, flood_map: dict, key: str, for_type: str) -> "FloodInfo":
try:
return flood_map[key]
except KeyError:
@ -86,8 +88,10 @@ class ReactBot(Plugin):
return fi
def is_flood(self, evt: MessageEvent) -> bool:
return (self._get_flood_info(self.user_flood, evt.sender, "user").bump()
or self._get_flood_info(self.room_flood, evt.room_id, "room").bump())
return (
self._get_flood_info(self.user_flood, evt.sender, "user").bump()
or self._get_flood_info(self.room_flood, evt.room_id, "room").bump()
)
@event.on(EventType.ROOM_MESSAGE)
async def event_handler(self, evt: MessageEvent) -> None:

View file

@ -13,18 +13,18 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Union, Dict, Any
from typing import Any, Dict, List, Union
import re
from jinja2 import Template as JinjaStringTemplate
from jinja2.nativetypes import NativeTemplate as JinjaNativeTemplate
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from mautrix.types import EventType
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from .rule import RPattern, Rule
from .simplepattern import SimplePattern
from .template import Template
from .rule import Rule, RPattern
InputPattern = Union[str, Dict[str, str]]
@ -49,28 +49,32 @@ class Config(BaseProxyConfig):
self.rules = {}
self.default_flags = self._get_flags(self["default_flags"])
self.templates = {name: self._make_template(name, tpl)
for name, tpl in self["templates"].items()}
self.rules = {name: self._make_rule(name, rule)
for name, rule in self["rules"].items()}
self.templates = {
name: self._make_template(name, tpl) for name, tpl in self["templates"].items()
}
self.rules = {name: self._make_rule(name, rule) for name, rule in self["rules"].items()}
def _make_rule(self, name: str, rule: Dict[str, Any]) -> Rule:
try:
return Rule(rooms=set(rule.get("rooms", [])),
return Rule(
rooms=set(rule.get("rooms", [])),
not_rooms=set(rule.get("not_rooms", [])),
matches=self._compile_all(rule["matches"]),
not_matches=self._compile_all(rule.get("not_matches", [])),
type=EventType.find(rule["type"]) if "type" in rule else None,
template=self.templates[rule["template"]],
variables=self._parse_variables(rule))
variables=self._parse_variables(rule),
)
except Exception as e:
raise ConfigError(f"Failed to load {name}") from e
def _make_template(self, name: str, tpl: Dict[str, Any]) -> Template:
try:
return Template(type=EventType.find(tpl.get("type", "m.room.message")),
return Template(
type=EventType.find(tpl.get("type", "m.room.message")),
variables=self._parse_variables(tpl),
content=self._parse_content(tpl.get("content", None))).init()
content=self._parse_content(tpl.get("content", None)),
).init()
except Exception as e:
raise ConfigError(f"Failed to load {name}") from e
@ -93,13 +97,19 @@ class Config(BaseProxyConfig):
@staticmethod
def _parse_variables(data: Dict[str, Any]) -> Dict[str, Any]:
return {name: (JinjaNativeTemplate(var_tpl)
return {
name: (
JinjaNativeTemplate(var_tpl)
if isinstance(var_tpl, str) and var_tpl.startswith("{{")
else var_tpl)
for name, var_tpl in data.get("variables", {}).items()}
else var_tpl
)
for name, var_tpl in data.get("variables", {}).items()
}
@staticmethod
def _parse_content(content: Union[Dict[str, Any], str]) -> Union[Dict[str, Any], JinjaStringTemplate]:
def _parse_content(
content: Union[Dict[str, Any], str]
) -> Union[Dict[str, Any], JinjaStringTemplate]:
if not content:
return {}
elif isinstance(content, str):

View file

@ -13,16 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Match, Dict, List, Set, Union, Pattern, Any
from typing import Any, Dict, List, Match, Optional, Pattern, Set, Union
from attr import dataclass
from mautrix.types import RoomID, EventType
from maubot import MessageEvent
from mautrix.types import EventType, RoomID
from .template import Template, OmitValue
from .simplepattern import SimplePattern
from .template import OmitValue, Template
RPattern = Union[Pattern, SimplePattern]
@ -59,7 +58,7 @@ class Rule:
async def execute(self, evt: MessageEvent, match: Match) -> None:
extra_vars = {
"0": match.group(0),
**{str(i+1): val for i, val in enumerate(match.groups())},
**{str(i + 1): val for i, val in enumerate(match.groups())},
**match.groupdict(),
}
content = self.template.execute(evt=evt, rule_vars=self.variables, extra_vars=extra_vars)

View file

@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, List, Dict, Optional, NamedTuple
from typing import Callable, Dict, List, NamedTuple, Optional
import re
@ -68,21 +68,22 @@ class SimplePattern:
return SimpleMatch(self.pattern)
@staticmethod
def compile(pattern: str, flags: re.RegexFlag = re.RegexFlag(0), force_raw: bool = False
) -> Optional['SimplePattern']:
def compile(
pattern: str, flags: re.RegexFlag = re.RegexFlag(0), force_raw: bool = False
) -> Optional["SimplePattern"]:
ignorecase = flags == re.IGNORECASE
s_pattern = pattern.lower() if ignorecase else pattern
esc = ""
if not force_raw:
esc = re.escape(pattern)
first, last = pattern[0], pattern[-1]
if first == '^' and last == '$' and (force_raw or esc == f"\\^{pattern[1:-1]}\\$"):
if first == "^" and last == "$" and (force_raw or esc == f"\\^{pattern[1:-1]}\\$"):
s_pattern = s_pattern[1:-1]
func = matcher_equals
elif first == '^' and (force_raw or esc == f"\\^{pattern[1:]}"):
elif first == "^" and (force_raw or esc == f"\\^{pattern[1:]}"):
s_pattern = s_pattern[1:]
func = matcher_startswith
elif last == '$' and (force_raw or esc == f"{pattern[:-1]}\\$"):
elif last == "$" and (force_raw or esc == f"{pattern[:-1]}\\$"):
s_pattern = s_pattern[:-1]
func = matcher_endswith
elif force_raw or esc == pattern:

View file

@ -13,17 +13,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, Dict, List, Tuple, Any
from typing import Any, Dict, List, Tuple, Union
from itertools import chain
import json
import copy
import json
import re
from attr import dataclass
from jinja2 import Template as JinjaStringTemplate
from jinja2.nativetypes import Template as JinjaNativeTemplate
from mautrix.types import EventType, Event
from mautrix.types import Event, EventType
class Key(str):
@ -48,7 +48,7 @@ class Template:
_variable_locations: List[Tuple[Index, ...]] = None
def init(self) -> 'Template':
def init(self) -> "Template":
self._variable_locations = []
self._map_variable_locations((), self.content)
return self
@ -80,13 +80,18 @@ class Template:
return variables[full_var_match.group(1)]
return variable_regex.sub(lambda match: str(variables[match.group(1)]), tpl)
def execute(self, evt: Event, rule_vars: Dict[str, Any], extra_vars: Dict[str, str]
def execute(
self, evt: Event, rule_vars: Dict[str, Any], extra_vars: Dict[str, str]
) -> Dict[str, Any]:
variables = extra_vars
for name, template in chain(rule_vars.items(), self.variables.items()):
if isinstance(template, JinjaNativeTemplate):
rendered_var = template.render(event=evt, variables=variables, **global_vars)
if not isinstance(rendered_var, (str, int, list, tuple, dict, bool)) and rendered_var is not None and rendered_var is not OmitValue:
if (
not isinstance(rendered_var, (str, int, list, tuple, dict, bool))
and rendered_var is not None
and rendered_var is not OmitValue
):
rendered_var = str(rendered_var)
variables[name] = rendered_var
else: