diff --git a/flanabot/bots/scraper_bot.py b/flanabot/bots/scraper_bot.py index 81244f9..3821a5c 100644 --- a/flanabot/bots/scraper_bot.py +++ b/flanabot/bots/scraper_bot.py @@ -7,8 +7,7 @@ from collections import defaultdict from typing import Iterable import flanautils -from flanaapis import instagram, reddit, tiktok, twitter, yt_dlp_wrapper -from flanaapis.scraping import constants as flanaapis_constants +from flanaapis import RedditMediaNotFoundError, instagram, reddit, tiktok, twitter, yt_dlp_wrapper from flanautils import Media, MediaType, OrderedSet, return_if_first_empty from multibot import MultiBot, RegisteredCallback, SendError, constants as multibot_constants, reply @@ -45,7 +44,7 @@ class ScraperBot(MultiBot, ABC): self.register(self._on_song_info, constants.KEYWORDS['song_info']) @staticmethod - async def _find_ids(text: str): + async def _find_ids(text: str) -> tuple[OrderedSet[str], ...]: return ( twitter.find_ids(text), instagram.find_ids(text), @@ -143,44 +142,67 @@ class ScraperBot(MultiBot, ABC): timeout_for_media: int | float = None ) -> OrderedSet[Media]: medias = OrderedSet() + exceptions: list[Exception] = [] - ids = await self._find_ids(message.text) - media_urls = () + ids = [] + media_urls = [] + for text_part in message.text.split(): + for i, platform_ids in enumerate(await self._find_ids(text_part)): + try: + ids[i] |= platform_ids + except IndexError: + ids.append(platform_ids) + if not any(ids) and flanautils.find_urls(text_part): + if force: + media_urls.append(text_part) + else: + if not any(domain.lower() in text_part for domain in multibot_constants.GIF_DOMAINS): + media_urls.append(text_part) - if ( - not any(ids) - and - ( - not (media_urls := flanautils.find_urls(message.text)) - or - ( - not force_gif_download - and - any(domain in url for url in media_urls for domain in flanaapis_constants.YT_DLP_WRAPPER_DISCARDED_DOMAINS) - ) - ) - ): + if not any(ids) and not media_urls: return medias bot_state_message = await self.send(random.choice(constants.SCRAPING_PHRASES), message) tweet_ids, instagram_ids, reddit_ids, tiktok_users_and_ids, tiktok_download_urls = ids + + try: + reddit_medias = await reddit.get_medias(reddit_ids, 'h264', 'mp4', force, audio_only, timeout_for_media) + except RedditMediaNotFoundError as e: + exceptions.append(e) + reddit_medias = () + reddit_urls = [] + for reddit_media in reddit_medias: + if reddit_media.source: + medias.add(reddit_media) + else: + reddit_urls.append(reddit_media.url) + if force: + media_urls.extend(reddit_urls) + else: + for reddit_url in reddit_urls: + for domain in multibot_constants.GIF_DOMAINS: + if domain.lower() in reddit_url: + medias.add(Media(reddit_url, MediaType.GIF, source=domain)) + break + else: + media_urls.append(reddit_url) + gather_result = asyncio.gather( twitter.get_medias(tweet_ids, audio_only), instagram.get_medias(instagram_ids, audio_only), - reddit.get_medias(reddit_ids, 'h264', 'mp4', audio_only, force_gif_download, timeout_for_media), - tiktok.get_medias(tiktok_users_and_ids, tiktok_download_urls, audio_only, 'h264', 'mp4', timeout_for_media), - yt_dlp_wrapper.get_medias(media_urls, 'h264', 'mp4', audio_only, force_gif_download, timeout_for_media), + tiktok.get_medias(tiktok_users_and_ids, tiktok_download_urls, 'h264', 'mp4', force, audio_only, timeout_for_media), + yt_dlp_wrapper.get_medias(media_urls, 'h264', 'mp4', force, audio_only, timeout_for_media), return_exceptions=True ) await gather_result await self.delete_message(bot_state_message) - medias, exceptions = flanautils.filter_exceptions(gather_result.result()) - await self._manage_exceptions(exceptions, message, print_traceback=True) + gather_medias, gather_exceptions = flanautils.filter_exceptions(gather_result.result()) + await self._manage_exceptions(exceptions + gather_exceptions, message, print_traceback=True) - return OrderedSet(*medias) + return medias | gather_medias # ---------------------------------------------- # # HANDLERS # @@ -224,6 +246,7 @@ class ScraperBot(MultiBot, ABC): # -------------------------------------------------------- # # -------------------- PUBLIC METHODS -------------------- # # -------------------------------------------------------- # + @return_if_first_empty(([], 0), exclude_self_types='ScraperBot', globals_=globals()) async def send_medias(self, medias: OrderedSet[Media], message: Message, send_song_info=False) -> tuple[list[Message], int]: sended_media_messages = [] fails = 0 @@ -254,6 +277,8 @@ class ScraperBot(MultiBot, ABC): ), multibot_constants.PARSER_MIN_SCORE_DEFAULT ) + and + flanautils.remove_symbols(word).lower() not in (str(self.id), self.name.lower()) )] ) if user_text: @@ -279,11 +304,11 @@ class ScraperBot(MultiBot, ABC): if send_song_info and media.song_info: await self.send_song_info(media.song_info, message) - if fails and sended_info_message: - if fails == len(medias): + if fails == len(medias): + if sended_info_message: await self.delete_message(sended_info_message) - if user_text_bot_message: - await self.delete_message(user_text_bot_message) + if user_text_bot_message: + await self.delete_message(user_text_bot_message) if bot_state_message: await self.delete_message(bot_state_message)