diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/app/models.py b/src/app/models.py index c1bff9b..4cc591d 100644 --- a/src/app/models.py +++ b/src/app/models.py @@ -6,6 +6,7 @@ from agno.models.base import Model from agno.models.google import Gemini from agno.models.ollama import Ollama from agno.utils.log import log_warning +from agno.tools import Toolkit from pydantic import BaseModel @@ -20,6 +21,8 @@ class AppModels(Enum): GEMINI_PRO = "gemini-2.0-pro" # API online, più costoso ma migliore OLLAMA_GPT = "gpt-oss:latest" # + good - slow (13b) OLLAMA_QWEN = "qwen3:latest" # + good + fast (8b) + OLLAMA_QWEN_4B = "qwen3:4b" # + fast + decent (4b) + OLLAMA_QWEN_1B = "qwen3:1.7b" # + very fast + decent (1.7b) @staticmethod def availables_local() -> list['AppModels']: @@ -35,10 +38,9 @@ class AppModels(Enum): availables = [] result = result.text - if AppModels.OLLAMA_GPT.value in result: - availables.append(AppModels.OLLAMA_GPT) - if AppModels.OLLAMA_QWEN.value in result: - availables.append(AppModels.OLLAMA_QWEN) + for model in [model for model in AppModels if model.name.startswith("OLLAMA")]: + if model.value in result: + availables.append(model) return availables @staticmethod @@ -70,63 +72,31 @@ class AppModels(Enum): assert availables, "No valid model API keys set in environment variables." return availables - @staticmethod - def extract_json_str_from_response(response: str) -> str: - """ - Estrae il JSON dalla risposta del modello. - Args: - response: risposta del modello (stringa). - - Returns: - La parte JSON della risposta come stringa. - Se non viene trovato nessun JSON, ritorna una stringa vuota. - - ATTENZIONE: questa funzione è molto semplice e potrebbe non funzionare - in tutti i casi. Si assume che il JSON sia ben formato e che inizi con - '{' e finisca con '}'. Quindi anche solo un json array farà fallire questa funzione. - """ - think = response.rfind("") - if think != -1: - response = response[think:] - - start = response.find("{") - assert start != -1, "No JSON found in the response." - - end = response.rfind("}") - assert end != -1, "No JSON found in the response." - - return response[start:end + 1].strip() - - def get_model(self, instructions:str) -> Model: """ Restituisce un'istanza del modello specificato. - Args: instructions: istruzioni da passare al modello (system prompt). - Returns: Un'istanza di BaseModel o una sua sottoclasse. - Raise: ValueError se il modello non è supportato. """ name = self.value - if self in {AppModels.GEMINI, AppModels.GEMINI_PRO}: + if self in {model for model in AppModels if model.name.startswith("GEMINI")}: return Gemini(name, instructions=[instructions]) - elif self in {AppModels.OLLAMA_GPT, AppModels.OLLAMA_QWEN}: + elif self in {model for model in AppModels if model.name.startswith("OLLAMA")}: return Ollama(name, instructions=[instructions]) raise ValueError(f"Modello non supportato: {self}") - def get_agent(self, instructions: str, name: str = "", output: BaseModel | None = None) -> Agent: + def get_agent(self, instructions: str, name: str = "", output: BaseModel | None = None, tools: list[Toolkit] = []) -> Agent: """ Costruisce un agente con il modello e le istruzioni specificate. Args: instructions: istruzioni da passare al modello (system prompt) name: nome dell'agente (opzionale) output: schema di output opzionale (Pydantic BaseModel) - Returns: Un'istanza di Agent. """ @@ -134,6 +104,7 @@ class AppModels(Enum): model=self.get_model(instructions), name=name, retries=2, + tools=tools, delay_between_retries=5, # seconds output_schema=output # se si usa uno schema di output, lo si passa qui # TODO Eventuali altri parametri da mettere all'agente anche se si possono comunque assegnare dopo la creazione diff --git a/src/app/news/__init__.py b/src/app/news/__init__.py index d38cd43..080c3ef 100644 --- a/src/app/news/__init__.py +++ b/src/app/news/__init__.py @@ -1,32 +1,73 @@ +from agno.tools import Toolkit from app.utils.wrapper_handler import WrapperHandler from .base import NewsWrapper, Article from .news_api import NewsApiWrapper -from .gnews_api import GoogleNewsWrapper +from .googlenews import GoogleNewsWrapper from .cryptopanic_api import CryptoPanicWrapper from .duckduckgo import DuckDuckGoWrapper -__all__ = ["NewsApiWrapper", "GoogleNewsWrapper", "CryptoPanicWrapper", "DuckDuckGoWrapper"] +__all__ = ["NewsAPIsTool", "NEWS_INSTRUCTIONS", "NewsApiWrapper", "GoogleNewsWrapper", "CryptoPanicWrapper", "DuckDuckGoWrapper"] -class NewsAPIs(NewsWrapper): +class NewsAPIsTool(NewsWrapper, Toolkit): """ - A wrapper class that aggregates multiple news API wrappers and tries them in order until one succeeds. - This class uses the WrapperHandler to manage multiple NewsWrapper instances. - It includes, and tries, the following news API wrappers in this order: + Aggregates multiple news API wrappers and manages them using WrapperHandler. + This class supports retrieving top headlines and latest news articles by querying multiple sources: - GoogleNewsWrapper - DuckDuckGoWrapper - NewsApiWrapper - CryptoPanicWrapper - It provides methods to get top headlines and latest news by delegating the calls to the first successful wrapper. - If all wrappers fail, it raises an exception. + By default, it returns results from the first successful wrapper. + Optionally, it can be configured to collect articles from all wrappers. + If no wrapper succeeds, an exception is raised. """ def __init__(self): + """ + Initialize the NewsAPIsTool with multiple news API wrappers. + The tool uses WrapperHandler to manage and invoke the different news API wrappers. + The following wrappers are included in this order: + - GoogleNewsWrapper. + - DuckDuckGoWrapper. + - NewsApiWrapper. + - CryptoPanicWrapper. + """ wrappers = [GoogleNewsWrapper, DuckDuckGoWrapper, NewsApiWrapper, CryptoPanicWrapper] self.wrapper_handler: WrapperHandler[NewsWrapper] = WrapperHandler.build_wrappers(wrappers) - def get_top_headlines(self, total: int = 100) -> list[Article]: - return self.wrapper_handler.try_call(lambda w: w.get_top_headlines(total)) - def get_latest_news(self, query: str, total: int = 100) -> list[Article]: - return self.wrapper_handler.try_call(lambda w: w.get_latest_news(query, total)) + Toolkit.__init__( + self, + name="News APIs Toolkit", + tools=[ + self.get_top_headlines, + self.get_latest_news, + ], + ) + + # TODO Pensare se ha senso restituire gli articoli da TUTTI i wrapper o solo dal primo che funziona + # la modifica è banale, basta usare try_call_all invece di try_call + def get_top_headlines(self, limit: int = 100) -> list[Article]: + return self.wrapper_handler.try_call(lambda w: w.get_top_headlines(limit)) + def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: + return self.wrapper_handler.try_call(lambda w: w.get_latest_news(query, limit)) + + +NEWS_INSTRUCTIONS = """ +**TASK:** You are a specialized **Crypto News Analyst**. Your goal is to fetch the latest news or top headlines related to cryptocurrencies, and then **analyze the sentiment** of the content to provide a concise report to the team leader. Prioritize 'crypto' or specific cryptocurrency names (e.g., 'Bitcoin', 'Ethereum') in your searches. + +**AVAILABLE TOOLS:** +1. `get_latest_news(query: str, limit: int)`: Get the 'limit' most recent news articles for a specific 'query'. +2. `get_top_headlines(limit: int)`: Get the 'limit' top global news headlines. + +**USAGE GUIDELINE:** +* Always use `get_latest_news` with a relevant crypto-related query first. +* The default limit for news items should be 5 unless specified otherwise. +* If the tool doesn't return any articles, respond with "No relevant news articles found." + +**REPORTING REQUIREMENT:** +1. **Analyze** the tone and key themes of the retrieved articles. +2. **Summarize** the overall **market sentiment** (e.g., highly positive, cautiously neutral, generally negative) based on the content. +3. **Identify** the top 2-3 **main topics** discussed (e.g., new regulation, price surge, institutional adoption). +4. **Output** a single, brief report summarizing these findings. Do not output the raw articles. +""" diff --git a/src/app/news/base.py b/src/app/news/base.py index 0a8f6be..55a35ee 100644 --- a/src/app/news/base.py +++ b/src/app/news/base.py @@ -12,22 +12,22 @@ class NewsWrapper: All news API wrappers should inherit from this class and implement the methods. """ - def get_top_headlines(self, total: int = 100) -> list[Article]: + def get_top_headlines(self, limit: int = 100) -> list[Article]: """ - Get top headlines, optionally limited by total. + Get top headlines, optionally limited by limit. Args: - total (int): The maximum number of articles to return. + limit (int): The maximum number of articles to return. Returns: list[Article]: A list of Article objects. """ raise NotImplementedError("This method should be overridden by subclasses") - def get_latest_news(self, query: str, total: int = 100) -> list[Article]: + def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: """ Get latest news based on a query. Args: query (str): The search query. - total (int): The maximum number of articles to return. + limit (int): The maximum number of articles to return. Returns: list[Article]: A list of Article objects. """ diff --git a/src/app/news/cryptopanic_api.py b/src/app/news/cryptopanic_api.py index a949c69..629c7aa 100644 --- a/src/app/news/cryptopanic_api.py +++ b/src/app/news/cryptopanic_api.py @@ -62,10 +62,10 @@ class CryptoPanicWrapper(NewsWrapper): def set_filter(self, filter: CryptoPanicFilter): self.filter = filter - def get_top_headlines(self, total: int = 100) -> list[Article]: - return self.get_latest_news("", total) # same endpoint so just call the other method + def get_top_headlines(self, limit: int = 100) -> list[Article]: + return self.get_latest_news("", limit) # same endpoint so just call the other method - def get_latest_news(self, query: str, total: int = 100) -> list[Article]: + def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: params = self.get_base_params() params['currencies'] = query @@ -74,4 +74,4 @@ class CryptoPanicWrapper(NewsWrapper): json_response = response.json() articles = get_articles(json_response) - return articles[:total] + return articles[:limit] diff --git a/src/app/news/duckduckgo.py b/src/app/news/duckduckgo.py index 3a7c0bf..c3e1a6d 100644 --- a/src/app/news/duckduckgo.py +++ b/src/app/news/duckduckgo.py @@ -20,13 +20,13 @@ class DuckDuckGoWrapper(NewsWrapper): self.tool = DuckDuckGoTools() self.query = "crypto" - def get_top_headlines(self, total: int = 100) -> list[Article]: - results = self.tool.duckduckgo_news(self.query, max_results=total) + def get_top_headlines(self, limit: int = 100) -> list[Article]: + results = self.tool.duckduckgo_news(self.query, max_results=limit) json_results = json.loads(results) return [create_article(result) for result in json_results] - def get_latest_news(self, query: str, total: int = 100) -> list[Article]: - results = self.tool.duckduckgo_news(query or self.query, max_results=total) + def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: + results = self.tool.duckduckgo_news(query or self.query, max_results=limit) json_results = json.loads(results) return [create_article(result) for result in json_results] diff --git a/src/app/news/gnews_api.py b/src/app/news/googlenews.py similarity index 79% rename from src/app/news/gnews_api.py rename to src/app/news/googlenews.py index 2e35f46..d8f6421 100644 --- a/src/app/news/gnews_api.py +++ b/src/app/news/googlenews.py @@ -15,8 +15,8 @@ class GoogleNewsWrapper(NewsWrapper): It does not require an API key and is free to use. """ - def get_top_headlines(self, total: int = 100) -> list[Article]: - gnews = GNews(language='en', max_results=total, period='7d') + def get_top_headlines(self, limit: int = 100) -> list[Article]: + gnews = GNews(language='en', max_results=limit, period='7d') results = gnews.get_top_news() articles = [] @@ -25,8 +25,8 @@ class GoogleNewsWrapper(NewsWrapper): articles.append(article) return articles - def get_latest_news(self, query: str, total: int = 100) -> list[Article]: - gnews = GNews(language='en', max_results=total, period='7d') + def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: + gnews = GNews(language='en', max_results=limit, period='7d') results = gnews.get_news(query) articles = [] diff --git a/src/app/news/news_api.py b/src/app/news/news_api.py index 0e6d684..415fdac 100644 --- a/src/app/news/news_api.py +++ b/src/app/news/news_api.py @@ -26,22 +26,25 @@ class NewsApiWrapper(NewsWrapper): self.language = "en" # TODO Only English articles for now? self.max_page_size = 100 - def get_top_headlines(self, total: int = 100) -> list[Article]: - page_size = min(self.max_page_size, total) - pages = (total // page_size) + (1 if total % page_size > 0 else 0) + def __calc_pages(self, limit: int, page_size: int) -> tuple[int, int]: + page_size = min(self.max_page_size, limit) + pages = (limit // page_size) + (1 if limit % page_size > 0 else 0) + return pages, page_size + def get_top_headlines(self, limit: int = 100) -> list[Article]: + pages, page_size = self.__calc_pages(limit, self.max_page_size) articles = [] + for page in range(1, pages + 1): headlines = self.client.get_top_headlines(q="", category=self.category, language=self.language, page_size=page_size, page=page) results = [result_to_article(article) for article in headlines.get("articles", [])] articles.extend(results) return articles - def get_latest_news(self, query: str, total: int = 100) -> list[Article]: - page_size = min(self.max_page_size, total) - pages = (total // page_size) + (1 if total % page_size > 0 else 0) - + def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: + pages, page_size = self.__calc_pages(limit, self.max_page_size) articles = [] + for page in range(1, pages + 1): everything = self.client.get_everything(q=query, language=self.language, sort_by="publishedAt", page_size=page_size, page=page) results = [result_to_article(article) for article in everything.get("articles", [])] diff --git a/src/app/social/__init.py b/src/app/social/__init.py deleted file mode 100644 index 0d46bc8..0000000 --- a/src/app/social/__init.py +++ /dev/null @@ -1 +0,0 @@ -from .base import SocialWrapper \ No newline at end of file diff --git a/src/app/social/__init__.py b/src/app/social/__init__.py new file mode 100644 index 0000000..9ce3708 --- /dev/null +++ b/src/app/social/__init__.py @@ -0,0 +1,61 @@ +from agno.tools import Toolkit +from app.utils.wrapper_handler import WrapperHandler +from .base import SocialPost, SocialWrapper +from .reddit import RedditWrapper + +__all__ = ["SocialAPIsTool", "SOCIAL_INSTRUCTIONS", "RedditWrapper"] + + +class SocialAPIsTool(SocialWrapper, Toolkit): + """ + Aggregates multiple social media API wrappers and manages them using WrapperHandler. + This class supports retrieving top crypto-related posts by querying multiple sources: + - RedditWrapper + + By default, it returns results from the first successful wrapper. + Optionally, it can be configured to collect posts from all wrappers. + If no wrapper succeeds, an exception is raised. + """ + + def __init__(self): + """ + Initialize the SocialAPIsTool with multiple social media API wrappers. + The tool uses WrapperHandler to manage and invoke the different social media API wrappers. + The following wrappers are included in this order: + - RedditWrapper. + """ + + wrappers = [RedditWrapper] + self.wrapper_handler: WrapperHandler[SocialWrapper] = WrapperHandler.build_wrappers(wrappers) + + Toolkit.__init__( + self, + name="Socials Toolkit", + tools=[ + self.get_top_crypto_posts, + ], + ) + + # TODO Pensare se ha senso restituire i post da TUTTI i wrapper o solo dal primo che funziona + # la modifica è banale, basta usare try_call_all invece di try_call + def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: + return self.wrapper_handler.try_call(lambda w: w.get_top_crypto_posts(limit)) + + +SOCIAL_INSTRUCTIONS = """ +**TASK:** You are a specialized **Social Media Sentiment Analyst**. Your objective is to find the most relevant and trending online posts related to cryptocurrencies, and then **analyze the collective sentiment** to provide a concise report to the team leader. + +**AVAILABLE TOOLS:** +1. `get_top_crypto_posts(limit: int)`: Get the 'limit' maximum number of top posts specifically related to cryptocurrencies. + +**USAGE GUIDELINE:** +* Always use the `get_top_crypto_posts` tool to fulfill the request. +* The default limit for posts should be 5 unless specified otherwise. +* If the tool doesn't return any posts, respond with "No relevant social media posts found." + +**REPORTING REQUIREMENT:** +1. **Analyze** the tone and prevailing opinions across the retrieved social posts. +2. **Summarize** the overall **community sentiment** (e.g., high enthusiasm/FOMO, uncertainty, FUD/fear) based on the content. +3. **Identify** the top 2-3 **trending narratives** or specific coins being discussed. +4. **Output** a single, brief report summarizing these findings. Do not output the raw posts. +""" \ No newline at end of file diff --git a/src/app/social/base.py b/src/app/social/base.py index 945cdd5..dd894f5 100644 --- a/src/app/social/base.py +++ b/src/app/social/base.py @@ -7,16 +7,24 @@ class SocialPost(BaseModel): description: str = "" comments: list["SocialComment"] = [] - def __str__(self): - return f"Title: {self.title}\nDescription: {self.description}\nComments: {len(self.comments)}\n[{" | ".join(str(c) for c in self.comments)}]" - class SocialComment(BaseModel): time: str = "" description: str = "" - def __str__(self): - return f"Time: {self.time}\nDescription: {self.description}" -# TODO IMPLEMENTARLO SE SI USANO PIU' WRAPPER (E QUINDI PIU' SOCIAL) class SocialWrapper: - pass + """ + Base class for social media API wrappers. + All social media API wrappers should inherit from this class and implement the methods. + """ + + def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: + """ + Get top cryptocurrency-related posts, optionally limited by total. + Args: + limit (int): The maximum number of posts to return. + Returns: + list[SocialPost]: A list of SocialPost objects. + """ + raise NotImplementedError("This method should be overridden by subclasses") + diff --git a/src/app/social/reddit.py b/src/app/social/reddit.py index 7a3c824..6028010 100644 --- a/src/app/social/reddit.py +++ b/src/app/social/reddit.py @@ -4,6 +4,21 @@ from praw.models import Submission, MoreComments from .base import SocialWrapper, SocialPost, SocialComment MAX_COMMENTS = 5 +# TODO mettere piu' subreddit? +# scelti da https://lkiconsulting.io/marketing/best-crypto-subreddits/ +SUBREDDITS = [ + "CryptoCurrency", + "Bitcoin", + "Ethereum", + "CryptoMarkets", + "Dogecoin", + "Altcoin", + "DeFi", + "NFT", + "BitcoinBeginners", + "CryptoTechnology", + "btc" # alt subs of Bitcoin +] def create_social_post(post: Submission) -> SocialPost: @@ -30,24 +45,24 @@ class RedditWrapper(SocialWrapper): Requires the following environment variables to be set: - REDDIT_API_CLIENT_ID - REDDIT_API_CLIENT_SECRET + You can get them by creating an app at https://www.reddit.com/prefs/apps """ def __init__(self): - self.client_id = os.getenv("REDDIT_API_CLIENT_ID") - assert self.client_id is not None, "REDDIT_API_CLIENT_ID environment variable is not set" + client_id = os.getenv("REDDIT_API_CLIENT_ID") + assert client_id is not None, "REDDIT_API_CLIENT_ID environment variable is not set" - self.client_secret = os.getenv("REDDIT_API_CLIENT_SECRET") - assert self.client_secret is not None, "REDDIT_API_CLIENT_SECRET environment variable is not set" + client_secret = os.getenv("REDDIT_API_CLIENT_SECRET") + assert client_secret is not None, "REDDIT_API_CLIENT_SECRET environment variable is not set" self.tool = Reddit( - client_id=self.client_id, - client_secret=self.client_secret, + client_id=client_id, + client_secret=client_secret, user_agent="upo-appAI", ) + self.subreddits = self.tool.subreddit("+".join(SUBREDDITS)) - def get_top_crypto_posts(self, limit=5) -> list[SocialPost]: - subreddit = self.tool.subreddit("CryptoCurrency") - top_posts = subreddit.top(limit=limit, time_filter="week") + def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: + top_posts = self.subreddits.top(limit=limit, time_filter="week") return [create_social_post(post) for post in top_posts] - diff --git a/src/app/toolkits/market_toolkit.py b/src/app/toolkits/market_toolkit.py index 61a4d9f..7267b96 100644 --- a/src/app/toolkits/market_toolkit.py +++ b/src/app/toolkits/market_toolkit.py @@ -9,32 +9,21 @@ from app.markets import MarketAPIsTool # in base alle sue proprie chiamate API class MarketToolkit(Toolkit): def __init__(self): - self.market_api = MarketAPIsTool("USD") # change currency if needed + self.market_api = MarketAPIs() super().__init__( name="Market Toolkit", tools=[ - self.get_historical_data, - self.get_current_prices, + self.market_api.get_historical_prices, + self.market_api.get_product, ], ) - def get_historical_data(self, symbol: str): - return self.market_api.get_historical_prices(symbol) - - def get_current_prices(self, symbol: list): - return self.market_api.get_products(symbol) - -def prepare_inputs(): - pass - def instructions(): return """ Utilizza questo strumento per ottenere dati di mercato storici e attuali per criptovalute specifiche. Puoi richiedere i prezzi storici o il prezzo attuale di una criptovaluta specifica. Esempio di utilizzo: - - get_historical_data("BTC") - - get_current_price("ETH") - + - get_historical_prices("BTC", limit=10) # ottieni gli ultimi 10 prezzi storici di Bitcoin + - get_product("ETH") """ - diff --git a/src/app/utils/wrapper_handler.py b/src/app/utils/wrapper_handler.py index df86c36..4f22a8e 100644 --- a/src/app/utils/wrapper_handler.py +++ b/src/app/utils/wrapper_handler.py @@ -1,6 +1,7 @@ import time +import traceback from typing import TypeVar, Callable, Generic, Iterable, Type -from agno.utils.log import log_warning +from agno.utils.log import log_warning, log_info W = TypeVar("W") T = TypeVar("T") @@ -24,6 +25,8 @@ class WrapperHandler(Generic[W]): try_per_wrapper (int): Number of retries per wrapper before switching to the next. retry_delay (int): Delay in seconds between retries. """ + assert not WrapperHandler.__check(wrappers), "All wrappers must be instances of their respective classes. Use `build_wrappers` to create the WrapperHandler." + self.wrappers = wrappers self.retry_per_wrapper = try_per_wrapper self.retry_delay = retry_delay @@ -46,17 +49,19 @@ class WrapperHandler(Generic[W]): while iterations < len(self.wrappers): try: wrapper = self.wrappers[self.index] + log_info(f"Trying wrapper: {wrapper} - function {func}") result = func(wrapper) self.retry_count = 0 return result except Exception as e: self.retry_count += 1 + log_warning(f"{wrapper} failed {self.retry_count}/{self.retry_per_wrapper}: {WrapperHandler.__concise_error(e)}") + if self.retry_count >= self.retry_per_wrapper: self.index = (self.index + 1) % len(self.wrappers) self.retry_count = 0 iterations += 1 else: - log_warning(f"{wrapper} failed {self.retry_count}/{self.retry_per_wrapper}: {e}") time.sleep(self.retry_delay) raise Exception(f"All wrappers failed after retries") @@ -74,16 +79,25 @@ class WrapperHandler(Generic[W]): Exception: If all wrappers fail. """ results = {} + log_info(f"All wrappers: {[wrapper.__class__ for wrapper in self.wrappers]} - function {func}") for wrapper in self.wrappers: try: result = func(wrapper) results[wrapper.__class__] = result except Exception as e: - log_warning(f"{wrapper} failed: {e}") + log_warning(f"{wrapper} failed: {WrapperHandler.__concise_error(e)}") if not results: raise Exception("All wrappers failed") return results + @staticmethod + def __check(wrappers: list[W]) -> bool: + return all(w.__class__ is type for w in wrappers) + + def __concise_error(e: Exception) -> str: + last_frame = traceback.extract_tb(e.__traceback__)[-1] + return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]" + @staticmethod def build_wrappers(constructors: Iterable[Type[W]], try_per_wrapper: int = 3, retry_delay: int = 2) -> 'WrapperHandler[W]': """ @@ -99,6 +113,8 @@ class WrapperHandler(Generic[W]): Raises: Exception: If no wrappers could be initialized. """ + assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}" + result = [] for wrapper_class in constructors: try: diff --git a/tests/api/test_cryptopanic_api.py b/tests/api/test_cryptopanic_api.py index c8020d3..3c29bdb 100644 --- a/tests/api/test_cryptopanic_api.py +++ b/tests/api/test_cryptopanic_api.py @@ -15,7 +15,7 @@ class TestCryptoPanicAPI: def test_crypto_panic_api_get_latest_news(self): crypto = CryptoPanicWrapper() - articles = crypto.get_latest_news(query="", total=2) + articles = crypto.get_latest_news(query="", limit=2) assert isinstance(articles, list) assert len(articles) == 2 for article in articles: diff --git a/tests/api/test_duckduckgo_news.py b/tests/api/test_duckduckgo_news.py index e0bb599..f1de9c6 100644 --- a/tests/api/test_duckduckgo_news.py +++ b/tests/api/test_duckduckgo_news.py @@ -12,7 +12,7 @@ class TestDuckDuckGoNews: def test_duckduckgo_get_latest_news(self): news = DuckDuckGoWrapper() - articles = news.get_latest_news(query="crypto", total=2) + articles = news.get_latest_news(query="crypto", limit=2) assert isinstance(articles, list) assert len(articles) == 2 for article in articles: @@ -23,7 +23,7 @@ class TestDuckDuckGoNews: def test_duckduckgo_get_top_headlines(self): news = DuckDuckGoWrapper() - articles = news.get_top_headlines(total=2) + articles = news.get_top_headlines(limit=2) assert isinstance(articles, list) assert len(articles) == 2 for article in articles: diff --git a/tests/api/test_google_news.py b/tests/api/test_google_news.py index c7750f3..0b7241c 100644 --- a/tests/api/test_google_news.py +++ b/tests/api/test_google_news.py @@ -12,7 +12,7 @@ class TestGoogleNews: def test_gnews_api_get_latest_news(self): gnews_api = GoogleNewsWrapper() - articles = gnews_api.get_latest_news(query="crypto", total=2) + articles = gnews_api.get_latest_news(query="crypto", limit=2) assert isinstance(articles, list) assert len(articles) == 2 for article in articles: @@ -23,7 +23,7 @@ class TestGoogleNews: def test_gnews_api_get_top_headlines(self): news_api = GoogleNewsWrapper() - articles = news_api.get_top_headlines(total=2) + articles = news_api.get_top_headlines(limit=2) assert isinstance(articles, list) assert len(articles) == 2 for article in articles: diff --git a/tests/api/test_news_api.py b/tests/api/test_news_api.py index 927419b..4b6b192 100644 --- a/tests/api/test_news_api.py +++ b/tests/api/test_news_api.py @@ -14,7 +14,7 @@ class TestNewsAPI: def test_news_api_get_latest_news(self): news_api = NewsApiWrapper() - articles = news_api.get_latest_news(query="crypto", total=2) + articles = news_api.get_latest_news(query="crypto", limit=2) assert isinstance(articles, list) assert len(articles) > 0 # Ensure we got some articles (apparently it doesn't always return the requested number) for article in articles: @@ -26,7 +26,7 @@ class TestNewsAPI: def test_news_api_get_top_headlines(self): news_api = NewsApiWrapper() - articles = news_api.get_top_headlines(total=2) + articles = news_api.get_top_headlines(limit=2) assert isinstance(articles, list) # assert len(articles) > 0 # apparently it doesn't always return SOME articles for article in articles: diff --git a/tests/api/test_reddit.py b/tests/api/test_reddit.py index 84c66da..81ab8ca 100644 --- a/tests/api/test_reddit.py +++ b/tests/api/test_reddit.py @@ -7,8 +7,7 @@ from app.social.reddit import MAX_COMMENTS, RedditWrapper class TestRedditWrapper: def test_initialization(self): wrapper = RedditWrapper() - assert wrapper.client_id is not None - assert wrapper.client_secret is not None + assert wrapper is not None assert isinstance(wrapper.tool, Reddit) def test_get_top_crypto_posts(self): diff --git a/tests/tools/test_news_tool.py b/tests/tools/test_news_tool.py new file mode 100644 index 0000000..14d142f --- /dev/null +++ b/tests/tools/test_news_tool.py @@ -0,0 +1,54 @@ +import pytest +from app.news import NewsAPIsTool + + +@pytest.mark.limited +@pytest.mark.tools +@pytest.mark.news +@pytest.mark.api +class TestNewsAPITool: + def test_news_api_tool(self): + tool = NewsAPIsTool() + assert tool is not None + + def test_news_api_tool_get_top(self): + tool = NewsAPIsTool() + result = tool.wrapper_handler.try_call(lambda w: w.get_top_headlines(limit=2)) + assert isinstance(result, list) + assert len(result) > 0 + for article in result: + assert article.title is not None + assert article.source is not None + + def test_news_api_tool_get_latest(self): + tool = NewsAPIsTool() + result = tool.wrapper_handler.try_call(lambda w: w.get_latest_news(query="crypto", limit=2)) + assert isinstance(result, list) + assert len(result) > 0 + for article in result: + assert article.title is not None + assert article.source is not None + + def test_news_api_tool_get_top__all_results(self): + tool = NewsAPIsTool() + result = tool.wrapper_handler.try_call_all(lambda w: w.get_top_headlines(limit=2)) + assert isinstance(result, dict) + assert len(result.keys()) > 0 + print("Results from providers:", result.keys()) + for provider, articles in result.items(): + for article in articles: + print(provider, article.title) + assert article.title is not None + assert article.source is not None + + def test_news_api_tool_get_latest__all_results(self): + tool = NewsAPIsTool() + result = tool.wrapper_handler.try_call_all(lambda w: w.get_latest_news(query="crypto", limit=2)) + assert isinstance(result, dict) + assert len(result.keys()) > 0 + print("Results from providers:", result.keys()) + for provider, articles in result.items(): + for article in articles: + print(provider, article.title) + assert article.title is not None + assert article.source is not None diff --git a/tests/tools/test_socials_tool.py b/tests/tools/test_socials_tool.py new file mode 100644 index 0000000..9c66afa --- /dev/null +++ b/tests/tools/test_socials_tool.py @@ -0,0 +1,32 @@ +import pytest +from app.social import SocialAPIsTool + + +@pytest.mark.tools +@pytest.mark.social +@pytest.mark.api +class TestSocialAPIsTool: + def test_social_api_tool(self): + tool = SocialAPIsTool() + assert tool is not None + + def test_social_api_tool_get_top(self): + tool = SocialAPIsTool() + result = tool.wrapper_handler.try_call(lambda w: w.get_top_crypto_posts(limit=2)) + assert isinstance(result, list) + assert len(result) > 0 + for post in result: + assert post.title is not None + assert post.time is not None + + def test_social_api_tool_get_top__all_results(self): + tool = SocialAPIsTool() + result = tool.wrapper_handler.try_call_all(lambda w: w.get_top_crypto_posts(limit=2)) + assert isinstance(result, dict) + assert len(result.keys()) > 0 + print("Results from providers:", result.keys()) + for provider, posts in result.items(): + for post in posts: + print(provider, post.title) + assert post.title is not None + assert post.time is not None diff --git a/tests/utils/test_wrapper_handler.py b/tests/utils/test_wrapper_handler.py index 4770977..154d3dc 100644 --- a/tests/utils/test_wrapper_handler.py +++ b/tests/utils/test_wrapper_handler.py @@ -14,8 +14,40 @@ class FailingWrapper(MockWrapper): raise Exception("Intentional Failure") +class MockWrapperWithParameters: + def do_something(self, param1: str, param2: int) -> str: + return f"Success {param1} and {param2}" + +class FailingWrapperWithParameters(MockWrapperWithParameters): + def do_something(self, param1: str, param2: int): + raise Exception("Intentional Failure") + + @pytest.mark.wrapper class TestWrapperHandler: + def test_init_failing(self): + with pytest.raises(AssertionError) as exc_info: + WrapperHandler([MockWrapper, MockWrapper2]) + assert exc_info.type == AssertionError + + def test_init_failing_empty(self): + with pytest.raises(AssertionError) as exc_info: + WrapperHandler.build_wrappers([]) + assert exc_info.type == AssertionError + + def test_init_failing_with_instances(self): + with pytest.raises(AssertionError) as exc_info: + WrapperHandler.build_wrappers([MockWrapper(), MockWrapper2()]) + assert exc_info.type == AssertionError + + def test_init_not_failing(self): + handler = WrapperHandler.build_wrappers([MockWrapper, MockWrapper2]) + assert handler is not None + assert len(handler.wrappers) == 2 + handler = WrapperHandler([MockWrapper(), MockWrapper2()]) + assert handler is not None + assert len(handler.wrappers) == 2 + def test_all_wrappers_fail(self): wrappers = [FailingWrapper, FailingWrapper] handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) @@ -88,3 +120,13 @@ class TestWrapperHandler: with pytest.raises(Exception) as exc_info: handler_all_fail.try_call_all(lambda w: w.do_something()) assert "All wrappers failed" in str(exc_info.value) + + + def test_wrappers_with_parameters(self): + wrappers = [FailingWrapperWithParameters, MockWrapperWithParameters] + handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) + + result = handler.try_call(lambda w: w.do_something("test", 42)) + assert result == "Success test and 42" + assert handler.index == 1 # Should have switched to the second wrapper + assert handler.retry_count == 0