Refactor ScraperBot.send_medias

This commit is contained in:
AlberLC
2022-12-21 05:25:38 +01:00
parent 9344aa2d16
commit cf6844e7a1

View File

@@ -7,8 +7,7 @@ from collections import defaultdict
from typing import Iterable
import flanautils
from flanaapis import instagram, reddit, tiktok, twitter, yt_dlp_wrapper
from flanaapis.scraping import constants as flanaapis_constants
from flanaapis import RedditMediaNotFoundError, instagram, reddit, tiktok, twitter, yt_dlp_wrapper
from flanautils import Media, MediaType, OrderedSet, return_if_first_empty
from multibot import MultiBot, RegisteredCallback, SendError, constants as multibot_constants, reply
@@ -45,7 +44,7 @@ class ScraperBot(MultiBot, ABC):
self.register(self._on_song_info, constants.KEYWORDS['song_info'])
@staticmethod
async def _find_ids(text: str):
async def _find_ids(text: str) -> tuple[OrderedSet[str], ...]:
return (
twitter.find_ids(text),
instagram.find_ids(text),
@@ -143,44 +142,67 @@ class ScraperBot(MultiBot, ABC):
timeout_for_media: int | float = None
) -> OrderedSet[Media]:
medias = OrderedSet()
exceptions: list[Exception] = []
ids = await self._find_ids(message.text)
media_urls = ()
ids = []
media_urls = []
for text_part in message.text.split():
for i, platform_ids in enumerate(await self._find_ids(text_part)):
try:
ids[i] |= platform_ids
except IndexError:
ids.append(platform_ids)
if not any(ids) and flanautils.find_urls(text_part):
if force:
media_urls.append(text_part)
else:
if not any(domain.lower() in text_part for domain in multibot_constants.GIF_DOMAINS):
media_urls.append(text_part)
if (
not any(ids)
and
(
not (media_urls := flanautils.find_urls(message.text))
or
(
not force_gif_download
and
any(domain in url for url in media_urls for domain in flanaapis_constants.YT_DLP_WRAPPER_DISCARDED_DOMAINS)
)
)
):
if not any(ids) and not media_urls:
return medias
bot_state_message = await self.send(random.choice(constants.SCRAPING_PHRASES), message)
tweet_ids, instagram_ids, reddit_ids, tiktok_users_and_ids, tiktok_download_urls = ids
try:
reddit_medias = await reddit.get_medias(reddit_ids, 'h264', 'mp4', force, audio_only, timeout_for_media)
except RedditMediaNotFoundError as e:
exceptions.append(e)
reddit_medias = ()
reddit_urls = []
for reddit_media in reddit_medias:
if reddit_media.source:
medias.add(reddit_media)
else:
reddit_urls.append(reddit_media.url)
if force:
media_urls.extend(reddit_urls)
else:
for reddit_url in reddit_urls:
for domain in multibot_constants.GIF_DOMAINS:
if domain.lower() in reddit_url:
medias.add(Media(reddit_url, MediaType.GIF, source=domain))
break
else:
media_urls.append(reddit_url)
gather_result = asyncio.gather(
twitter.get_medias(tweet_ids, audio_only),
instagram.get_medias(instagram_ids, audio_only),
reddit.get_medias(reddit_ids, 'h264', 'mp4', audio_only, force_gif_download, timeout_for_media),
tiktok.get_medias(tiktok_users_and_ids, tiktok_download_urls, audio_only, 'h264', 'mp4', timeout_for_media),
yt_dlp_wrapper.get_medias(media_urls, 'h264', 'mp4', audio_only, force_gif_download, timeout_for_media),
tiktok.get_medias(tiktok_users_and_ids, tiktok_download_urls, 'h264', 'mp4', force, audio_only, timeout_for_media),
yt_dlp_wrapper.get_medias(media_urls, 'h264', 'mp4', force, audio_only, timeout_for_media),
return_exceptions=True
)
await gather_result
await self.delete_message(bot_state_message)
medias, exceptions = flanautils.filter_exceptions(gather_result.result())
await self._manage_exceptions(exceptions, message, print_traceback=True)
gather_medias, gather_exceptions = flanautils.filter_exceptions(gather_result.result())
await self._manage_exceptions(exceptions + gather_exceptions, message, print_traceback=True)
return OrderedSet(*medias)
return medias | gather_medias
# ---------------------------------------------- #
# HANDLERS #
@@ -224,6 +246,7 @@ class ScraperBot(MultiBot, ABC):
# -------------------------------------------------------- #
# -------------------- PUBLIC METHODS -------------------- #
# -------------------------------------------------------- #
@return_if_first_empty(([], 0), exclude_self_types='ScraperBot', globals_=globals())
async def send_medias(self, medias: OrderedSet[Media], message: Message, send_song_info=False) -> tuple[list[Message], int]:
sended_media_messages = []
fails = 0
@@ -254,6 +277,8 @@ class ScraperBot(MultiBot, ABC):
),
multibot_constants.PARSER_MIN_SCORE_DEFAULT
)
and
flanautils.remove_symbols(word).lower() not in (str(self.id), self.name.lower())
)]
)
if user_text:
@@ -279,11 +304,11 @@ class ScraperBot(MultiBot, ABC):
if send_song_info and media.song_info:
await self.send_song_info(media.song_info, message)
if fails and sended_info_message:
if fails == len(medias):
if fails == len(medias):
if sended_info_message:
await self.delete_message(sended_info_message)
if user_text_bot_message:
await self.delete_message(user_text_bot_message)
if user_text_bot_message:
await self.delete_message(user_text_bot_message)
if bot_state_message:
await self.delete_message(bot_state_message)