diff --git a/flanabot/bots/flana_bot.py b/flanabot/bots/flana_bot.py index 18a9cbd..c499a7f 100644 --- a/flanabot/bots/flana_bot.py +++ b/flanabot/bots/flana_bot.py @@ -492,7 +492,7 @@ class FlanaBot(MultiBot, ABC): async def _on_ready(self): await super()._on_ready() - await flanautils.do_every(constants.CHECK_PUNISHMENTS_EVERY_SECONDS, Punishment.check_olds, self._unpunish, self.platform) + await flanautils.do_every(constants.CHECK_PUNISHMENTS_EVERY_SECONDS, self.check_old_punishments) @inline(False) async def _on_recover_message(self, message: Message): @@ -797,13 +797,53 @@ class FlanaBot(MultiBot, ABC): # -------------------------------------------------------- # # -------------------- PUBLIC METHODS -------------------- # # -------------------------------------------------------- # + async def check_old_punishments(self): + punishments = Punishment.find({'platform': self.platform.value}, lazy=True) + + for punishment in punishments: + now = datetime.datetime.now(datetime.timezone.utc) + if not punishment.until or now < punishment.until: + continue + + await self._remove_penalty(punishment, self._unpunish, delete=False) + if punishment.is_active: + punishment.is_active = False + punishment.last_update = now + punishment.save() + + if punishment.last_update + constants.PUNISHMENTS_RESET_TIME <= now: + if punishment.level == 1: + punishment.delete() + else: + punishment.level -= 1 + punishment.last_update = now + punishment.save() + async def is_punished(self, user: int | str | User, group_: int | str | Chat | Message) -> bool: pass - async def punish(self, user: int | str | User, group_: int | str | Chat | Message, time: int | datetime.timedelta = None, message: Message = None): + async def punish( + self, + user: int | str | User, + group_: int | str | Chat | Message, + time: int | datetime.timedelta = None, + message: Message = None + ): # noinspection PyTypeChecker punishment = Punishment(self.platform, self.get_user_id(user), self.get_group_id(group_), time) - await punishment.apply(self._punish, self._unpunish, message) + punishment.pull_from_database(overwrite_fields=('level',), exclude_fields=('until',)) + punishment.level += 1 + + try: + await self._punish(punishment.user_id, punishment.group_id) + except BadRoleError as e: + if message and message.chat.original_object: + await self._manage_exceptions(e, message) + else: + raise e + else: + punishment.save(pull_exclude_fields=('until',)) + await self._unpenalize_later(punishment, self._unpunish, message) async def send_bye(self, message: Message) -> multibot_constants.ORIGINAL_MESSAGE: return await self.send(random.choice((*constants.BYE_PHRASES, flanautils.CommonWords.random_time_greeting())), message) @@ -870,4 +910,4 @@ class FlanaBot(MultiBot, ABC): async def unpunish(self, user: int | str | User, group_: int | str | Chat | Message, message: Message = None): # noinspection PyTypeChecker punishment = Punishment(self.platform, self.get_user_id(user), self.get_group_id(group_)) - await punishment.remove(self._unpunish, message) + await self._remove_penalty(punishment, self._unpunish, message) diff --git a/flanabot/models/punishment.py b/flanabot/models/punishment.py index c7f8df4..1b1fbfe 100644 --- a/flanabot/models/punishment.py +++ b/flanabot/models/punishment.py @@ -1,17 +1,13 @@ __all__ = ['Punishment'] -import datetime from dataclasses import dataclass -from typing import Any, Callable +from typing import Any -from multibot.models import Platform, PunishmentBase - -from flanabot import constants -from flanabot.models.message import Message +from multibot.models import Penalty @dataclass(eq=False) -class Punishment(PunishmentBase): +class Punishment(Penalty): collection_name = 'punishment' level: int = 0 @@ -20,32 +16,3 @@ class Punishment(PunishmentBase): self_vars = super()._mongo_repr() self_vars['level'] = self.level return self_vars - - async def apply(self, punishment_method: Callable, unpunishment_method: Callable, message: Message = None): - self.pull_from_database(overwrite_fields=('level',), exclude_fields=('until',)) - self.level += 1 - - await super().apply(punishment_method, unpunishment_method, message) - - @classmethod - async def check_olds(cls, unpunishment_method: Callable, platform: Platform): - punishments = cls.find({'platform': platform.value}) - - for punishment in punishments: - now = datetime.datetime.now(datetime.timezone.utc) - if not punishment.until or now < punishment.until: - continue - - await punishment.remove(unpunishment_method, delete=False) - if punishment.is_active: - punishment.is_active = False - punishment.last_update = now - punishment.save() - - if punishment.last_update + constants.PUNISHMENTS_RESET_TIME <= now: - if punishment.level == 1: - punishment.delete() - else: - punishment.level -= 1 - punishment.last_update = now - punishment.save()