Refactor ScraperBot.send_medias
@@ -7,8 +7,7 @@ from collections import defaultdict
 from typing import Iterable
 
 import flanautils
-from flanaapis import instagram, reddit, tiktok, twitter, yt_dlp_wrapper
-from flanaapis.scraping import constants as flanaapis_constants
+from flanaapis import RedditMediaNotFoundError, instagram, reddit, tiktok, twitter, yt_dlp_wrapper
 from flanautils import Media, MediaType, OrderedSet, return_if_first_empty
 from multibot import MultiBot, RegisteredCallback, SendError, constants as multibot_constants, reply
 
@@ -45,7 +44,7 @@ class ScraperBot(MultiBot, ABC):
         self.register(self._on_song_info, constants.KEYWORDS['song_info'])
 
     @staticmethod
-    async def _find_ids(text: str):
+    async def _find_ids(text: str) -> tuple[OrderedSet[str], ...]:
         return (
             twitter.find_ids(text),
             instagram.find_ids(text),
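The tightened annotation documents a positional contract: _find_ids always yields one OrderedSet per supported platform, in a fixed order, which is what lets send_medias unpack the result into tweet_ids, instagram_ids, reddit_ids, tiktok_users_and_ids and tiktok_download_urls further down. A minimal sketch of a caller relying on that order, assuming ScraperBot is importable from the project (the example URL and the asyncio.run harness are illustrative, not part of the commit):

import asyncio

async def main() -> None:
    # One OrderedSet[str] per platform, always in the same order, so the
    # caller can unpack positionally instead of probing by platform name.
    tweet_ids, instagram_ids, reddit_ids, tiktok_users_and_ids, tiktok_download_urls = (
        await ScraperBot._find_ids('https://twitter.com/someone/status/123')
    )

asyncio.run(main())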
@@ -143,44 +142,67 @@ class ScraperBot(MultiBot, ABC):
         timeout_for_media: int | float = None
     ) -> OrderedSet[Media]:
         medias = OrderedSet()
+        exceptions: list[Exception] = []
 
-        ids = await self._find_ids(message.text)
-        media_urls = ()
+        ids = []
+        media_urls = []
+        for text_part in message.text.split():
+            for i, platform_ids in enumerate(await self._find_ids(text_part)):
+                try:
+                    ids[i] |= platform_ids
+                except IndexError:
+                    ids.append(platform_ids)
+            if not any(ids) and flanautils.find_urls(text_part):
+                if force:
+                    media_urls.append(text_part)
+                else:
+                    if not any(domain.lower() in text_part for domain in multibot_constants.GIF_DOMAINS):
+                        media_urls.append(text_part)
 
-        if (
-            not any(ids)
-            and
-            (
-                not (media_urls := flanautils.find_urls(message.text))
-                or
-                (
-                    not force_gif_download
-                    and
-                    any(domain in url for url in media_urls for domain in flanaapis_constants.YT_DLP_WRAPPER_DISCARDED_DOMAINS)
-                )
-            )
-        ):
+        if not any(ids) and not media_urls:
             return medias
 
         bot_state_message = await self.send(random.choice(constants.SCRAPING_PHRASES), message)
 
         tweet_ids, instagram_ids, reddit_ids, tiktok_users_and_ids, tiktok_download_urls = ids
 
+        try:
+            reddit_medias = await reddit.get_medias(reddit_ids, 'h264', 'mp4', force, audio_only, timeout_for_media)
+        except RedditMediaNotFoundError as e:
+            exceptions.append(e)
+            reddit_medias = ()
+        reddit_urls = []
+        for reddit_media in reddit_medias:
+            if reddit_media.source:
+                medias.add(reddit_media)
+            else:
+                reddit_urls.append(reddit_media.url)
+        if force:
+            media_urls.extend(reddit_urls)
+        else:
+            for reddit_url in reddit_urls:
+                for domain in multibot_constants.GIF_DOMAINS:
+                    if domain.lower() in reddit_url:
+                        medias.add(Media(reddit_url, MediaType.GIF, source=domain))
+                        break
+                else:
+                    media_urls.append(reddit_url)
+
         gather_result = asyncio.gather(
             twitter.get_medias(tweet_ids, audio_only),
             instagram.get_medias(instagram_ids, audio_only),
-            reddit.get_medias(reddit_ids, 'h264', 'mp4', audio_only, force_gif_download, timeout_for_media),
-            tiktok.get_medias(tiktok_users_and_ids, tiktok_download_urls, audio_only, 'h264', 'mp4', timeout_for_media),
-            yt_dlp_wrapper.get_medias(media_urls, 'h264', 'mp4', audio_only, force_gif_download, timeout_for_media),
+            tiktok.get_medias(tiktok_users_and_ids, tiktok_download_urls, 'h264', 'mp4', force, audio_only, timeout_for_media),
+            yt_dlp_wrapper.get_medias(media_urls, 'h264', 'mp4', force, audio_only, timeout_for_media),
             return_exceptions=True
         )
 
         await gather_result
         await self.delete_message(bot_state_message)
 
-        medias, exceptions = flanautils.filter_exceptions(gather_result.result())
-        await self._manage_exceptions(exceptions, message, print_traceback=True)
+        gather_medias, gather_exceptions = flanautils.filter_exceptions(gather_result.result())
+        await self._manage_exceptions(exceptions + gather_exceptions, message, print_traceback=True)
 
-        return OrderedSet(*medias)
+        return medias | gather_medias
 
     # ---------------------------------------------- #
     #                    HANDLERS                    #
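Two patterns in the new body are worth isolating. First, ids is now built incrementally across the whitespace-split parts of the message: each part's per-platform OrderedSets are unioned into a running slot with |=, and the IndexError branch seeds the slot on the first pass. Second, the platform scrapers still run concurrently under asyncio.gather(..., return_exceptions=True), but the Reddit call is pulled out in front so a RedditMediaNotFoundError can be collected into the local exceptions list and reported together with the gather failures. A self-contained sketch of both patterns, where find_ids_for_platforms, scrape_ok and scrape_broken are hypothetical stand-ins for ScraperBot._find_ids and the real platform modules, and flanautils.filter_exceptions is assumed to split a mixed result list into (values, exceptions):

import asyncio

import flanautils
from flanautils import OrderedSet


async def find_ids_for_platforms(text_part: str) -> tuple[OrderedSet, ...]:
    # Stand-in for ScraperBot._find_ids: one OrderedSet per platform.
    return OrderedSet(), OrderedSet()


async def scrape_ok() -> list[str]:
    return ['media']


async def scrape_broken() -> list[str]:
    raise ValueError('platform down')


async def main() -> None:
    # Pattern 1: merge per-platform id sets across message parts.
    ids: list[OrderedSet] = []
    for text_part in 'https://a.example https://b.example'.split():
        for i, platform_ids in enumerate(await find_ids_for_platforms(text_part)):
            try:
                ids[i] |= platform_ids    # union into the existing slot
            except IndexError:
                ids.append(platform_ids)  # the first part seeds the slot

    # Pattern 2: return_exceptions=True keeps one failing scraper from
    # discarding the others' results; splitting values from exceptions
    # afterwards lets locally collected exceptions (the Reddit case above)
    # be reported in the same _manage_exceptions pass.
    results = await asyncio.gather(scrape_ok(), scrape_broken(), return_exceptions=True)
    medias, exceptions = flanautils.filter_exceptions(results)


asyncio.run(main())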
@@ -224,6 +246,7 @@ class ScraperBot(MultiBot, ABC):
     # -------------------------------------------------------- #
     # -------------------- PUBLIC METHODS -------------------- #
     # -------------------------------------------------------- #
+    @return_if_first_empty(([], 0), exclude_self_types='ScraperBot', globals_=globals())
     async def send_medias(self, medias: OrderedSet[Media], message: Message, send_song_info=False) -> tuple[list[Message], int]:
         sended_media_messages = []
         fails = 0
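The newly applied @return_if_first_empty(([], 0), ...) decorator gives send_medias an early exit that matches its tuple[list[Message], int] return type when it is called with no medias. A simplified, hypothetical re-implementation of the contract the call site suggests (the real flanautils decorator also honours exclude_self_types and globals_, which are omitted here):

import functools


def return_if_first_empty_sketch(default):
    # If the first real argument is empty/falsy, short-circuit with `default`
    # instead of running the coroutine body.
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(self, first, *args, **kwargs):
            if not first:
                return default
            return await func(self, first, *args, **kwargs)

        return wrapper

    return decorator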
@@ -254,6 +277,8 @@ class ScraperBot(MultiBot, ABC):
                     ),
                     multibot_constants.PARSER_MIN_SCORE_DEFAULT
                 )
+                and
+                flanautils.remove_symbols(word).lower() not in (str(self.id), self.name.lower())
             )]
         )
         if user_text:
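The added condition filters the bot's own id and name out of the echoed user text, so a message that consists only of a mention of the bot is not re-sent as user text. In isolation, with bot_id and bot_name as hypothetical stand-ins for self.id and self.name, and remove_symbols assumed to strip punctuation such as the @ of a mention:

import flanautils

bot_id, bot_name = 123456789, 'FlanaBot'

word = '@FlanaBot'
keep = flanautils.remove_symbols(word).lower() not in (str(bot_id), bot_name.lower())
# keep is False here: the word is just the bot's name, so it is dropped.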
@@ -279,11 +304,11 @@ class ScraperBot(MultiBot, ABC):
             if send_song_info and media.song_info:
                 await self.send_song_info(media.song_info, message)
 
-        if fails and sended_info_message:
-            if fails == len(medias):
+        if fails == len(medias):
+            if sended_info_message:
                 await self.delete_message(sended_info_message)
-                if user_text_bot_message:
-                    await self.delete_message(user_text_bot_message)
+            if user_text_bot_message:
+                await self.delete_message(user_text_bot_message)
 
         if bot_state_message:
             await self.delete_message(bot_state_message)