Refactor penalty system
This commit is contained in:
@@ -492,7 +492,7 @@ class FlanaBot(MultiBot, ABC):
|
||||
|
||||
async def _on_ready(self):
|
||||
await super()._on_ready()
|
||||
await flanautils.do_every(constants.CHECK_PUNISHMENTS_EVERY_SECONDS, Punishment.check_olds, self._unpunish, self.platform)
|
||||
await flanautils.do_every(constants.CHECK_PUNISHMENTS_EVERY_SECONDS, self.check_old_punishments)
|
||||
|
||||
@inline(False)
|
||||
async def _on_recover_message(self, message: Message):
|
||||
@@ -797,13 +797,53 @@ class FlanaBot(MultiBot, ABC):
|
||||
# -------------------------------------------------------- #
|
||||
# -------------------- PUBLIC METHODS -------------------- #
|
||||
# -------------------------------------------------------- #
|
||||
async def check_old_punishments(self):
|
||||
punishments = Punishment.find({'platform': self.platform.value}, lazy=True)
|
||||
|
||||
for punishment in punishments:
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
if not punishment.until or now < punishment.until:
|
||||
continue
|
||||
|
||||
await self._remove_penalty(punishment, self._unpunish, delete=False)
|
||||
if punishment.is_active:
|
||||
punishment.is_active = False
|
||||
punishment.last_update = now
|
||||
punishment.save()
|
||||
|
||||
if punishment.last_update + constants.PUNISHMENTS_RESET_TIME <= now:
|
||||
if punishment.level == 1:
|
||||
punishment.delete()
|
||||
else:
|
||||
punishment.level -= 1
|
||||
punishment.last_update = now
|
||||
punishment.save()
|
||||
|
||||
async def is_punished(self, user: int | str | User, group_: int | str | Chat | Message) -> bool:
|
||||
pass
|
||||
|
||||
async def punish(self, user: int | str | User, group_: int | str | Chat | Message, time: int | datetime.timedelta = None, message: Message = None):
|
||||
async def punish(
|
||||
self,
|
||||
user: int | str | User,
|
||||
group_: int | str | Chat | Message,
|
||||
time: int | datetime.timedelta = None,
|
||||
message: Message = None
|
||||
):
|
||||
# noinspection PyTypeChecker
|
||||
punishment = Punishment(self.platform, self.get_user_id(user), self.get_group_id(group_), time)
|
||||
await punishment.apply(self._punish, self._unpunish, message)
|
||||
punishment.pull_from_database(overwrite_fields=('level',), exclude_fields=('until',))
|
||||
punishment.level += 1
|
||||
|
||||
try:
|
||||
await self._punish(punishment.user_id, punishment.group_id)
|
||||
except BadRoleError as e:
|
||||
if message and message.chat.original_object:
|
||||
await self._manage_exceptions(e, message)
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
punishment.save(pull_exclude_fields=('until',))
|
||||
await self._unpenalize_later(punishment, self._unpunish, message)
|
||||
|
||||
async def send_bye(self, message: Message) -> multibot_constants.ORIGINAL_MESSAGE:
|
||||
return await self.send(random.choice((*constants.BYE_PHRASES, flanautils.CommonWords.random_time_greeting())), message)
|
||||
@@ -870,4 +910,4 @@ class FlanaBot(MultiBot, ABC):
|
||||
async def unpunish(self, user: int | str | User, group_: int | str | Chat | Message, message: Message = None):
|
||||
# noinspection PyTypeChecker
|
||||
punishment = Punishment(self.platform, self.get_user_id(user), self.get_group_id(group_))
|
||||
await punishment.remove(self._unpunish, message)
|
||||
await self._remove_penalty(punishment, self._unpunish, message)
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
__all__ = ['Punishment']
|
||||
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from multibot.models import Platform, PunishmentBase
|
||||
|
||||
from flanabot import constants
|
||||
from flanabot.models.message import Message
|
||||
from multibot.models import Penalty
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class Punishment(PunishmentBase):
|
||||
class Punishment(Penalty):
|
||||
collection_name = 'punishment'
|
||||
|
||||
level: int = 0
|
||||
@@ -20,32 +16,3 @@ class Punishment(PunishmentBase):
|
||||
self_vars = super()._mongo_repr()
|
||||
self_vars['level'] = self.level
|
||||
return self_vars
|
||||
|
||||
async def apply(self, punishment_method: Callable, unpunishment_method: Callable, message: Message = None):
|
||||
self.pull_from_database(overwrite_fields=('level',), exclude_fields=('until',))
|
||||
self.level += 1
|
||||
|
||||
await super().apply(punishment_method, unpunishment_method, message)
|
||||
|
||||
@classmethod
|
||||
async def check_olds(cls, unpunishment_method: Callable, platform: Platform):
|
||||
punishments = cls.find({'platform': platform.value})
|
||||
|
||||
for punishment in punishments:
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
if not punishment.until or now < punishment.until:
|
||||
continue
|
||||
|
||||
await punishment.remove(unpunishment_method, delete=False)
|
||||
if punishment.is_active:
|
||||
punishment.is_active = False
|
||||
punishment.last_update = now
|
||||
punishment.save()
|
||||
|
||||
if punishment.last_update + constants.PUNISHMENTS_RESET_TIME <= now:
|
||||
if punishment.level == 1:
|
||||
punishment.delete()
|
||||
else:
|
||||
punishment.level -= 1
|
||||
punishment.last_update = now
|
||||
punishment.save()
|
||||
|
||||
Reference in New Issue
Block a user