2 news api #10
@@ -6,6 +6,7 @@ from agno.models.base import Model
|
|||||||
from agno.models.google import Gemini
|
from agno.models.google import Gemini
|
||||||
from agno.models.ollama import Ollama
|
from agno.models.ollama import Ollama
|
||||||
from agno.utils.log import log_warning
|
from agno.utils.log import log_warning
|
||||||
|
from agno.tools import Toolkit
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@@ -20,6 +21,8 @@ class AppModels(Enum):
|
|||||||
GEMINI_PRO = "gemini-2.0-pro" # API online, più costoso ma migliore
|
GEMINI_PRO = "gemini-2.0-pro" # API online, più costoso ma migliore
|
||||||
OLLAMA_GPT = "gpt-oss:latest" # + good - slow (13b)
|
OLLAMA_GPT = "gpt-oss:latest" # + good - slow (13b)
|
||||||
OLLAMA_QWEN = "qwen3:latest" # + good + fast (8b)
|
OLLAMA_QWEN = "qwen3:latest" # + good + fast (8b)
|
||||||
|
OLLAMA_QWEN_4B = "qwen3:4b" # + fast + decent (4b)
|
||||||
|
OLLAMA_QWEN_1B = "qwen3:1.7b" # + very fast + decent (1.7b)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def availables_local() -> list['AppModels']:
|
def availables_local() -> list['AppModels']:
|
||||||
@@ -35,10 +38,9 @@ class AppModels(Enum):
|
|||||||
|
|
||||||
availables = []
|
availables = []
|
||||||
result = result.text
|
result = result.text
|
||||||
if AppModels.OLLAMA_GPT.value in result:
|
for model in [model for model in AppModels if model.name.startswith("OLLAMA")]:
|
||||||
availables.append(AppModels.OLLAMA_GPT)
|
if model.value in result:
|
||||||
if AppModels.OLLAMA_QWEN.value in result:
|
availables.append(model)
|
||||||
availables.append(AppModels.OLLAMA_QWEN)
|
|
||||||
return availables
|
return availables
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -70,63 +72,31 @@ class AppModels(Enum):
|
|||||||
assert availables, "No valid model API keys set in environment variables."
|
assert availables, "No valid model API keys set in environment variables."
|
||||||
return availables
|
return availables
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_json_str_from_response(response: str) -> str:
|
|
||||||
"""
|
|
||||||
Estrae il JSON dalla risposta del modello.
|
|
||||||
Args:
|
|
||||||
response: risposta del modello (stringa).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
La parte JSON della risposta come stringa.
|
|
||||||
Se non viene trovato nessun JSON, ritorna una stringa vuota.
|
|
||||||
|
|
||||||
ATTENZIONE: questa funzione è molto semplice e potrebbe non funzionare
|
|
||||||
in tutti i casi. Si assume che il JSON sia ben formato e che inizi con
|
|
||||||
'{' e finisca con '}'. Quindi anche solo un json array farà fallire questa funzione.
|
|
||||||
"""
|
|
||||||
think = response.rfind("</think>")
|
|
||||||
if think != -1:
|
|
||||||
response = response[think:]
|
|
||||||
|
|
||||||
start = response.find("{")
|
|
||||||
assert start != -1, "No JSON found in the response."
|
|
||||||
|
|
||||||
end = response.rfind("}")
|
|
||||||
assert end != -1, "No JSON found in the response."
|
|
||||||
|
|
||||||
return response[start:end + 1].strip()
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(self, instructions:str) -> Model:
|
def get_model(self, instructions:str) -> Model:
|
||||||
"""
|
"""
|
||||||
Restituisce un'istanza del modello specificato.
|
Restituisce un'istanza del modello specificato.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: istruzioni da passare al modello (system prompt).
|
instructions: istruzioni da passare al modello (system prompt).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Un'istanza di BaseModel o una sua sottoclasse.
|
Un'istanza di BaseModel o una sua sottoclasse.
|
||||||
|
|
||||||
Raise:
|
Raise:
|
||||||
ValueError se il modello non è supportato.
|
ValueError se il modello non è supportato.
|
||||||
"""
|
"""
|
||||||
name = self.value
|
name = self.value
|
||||||
if self in {AppModels.GEMINI, AppModels.GEMINI_PRO}:
|
if self in {model for model in AppModels if model.name.startswith("GEMINI")}:
|
||||||
return Gemini(name, instructions=[instructions])
|
return Gemini(name, instructions=[instructions])
|
||||||
elif self in {AppModels.OLLAMA_GPT, AppModels.OLLAMA_QWEN}:
|
elif self in {model for model in AppModels if model.name.startswith("OLLAMA")}:
|
||||||
return Ollama(name, instructions=[instructions])
|
return Ollama(name, instructions=[instructions])
|
||||||
|
|
||||||
raise ValueError(f"Modello non supportato: {self}")
|
raise ValueError(f"Modello non supportato: {self}")
|
||||||
|
|
||||||
def get_agent(self, instructions: str, name: str = "", output: BaseModel | None = None) -> Agent:
|
def get_agent(self, instructions: str, name: str = "", output: BaseModel | None = None, tools: list[Toolkit] = []) -> Agent:
|
||||||
"""
|
"""
|
||||||
Costruisce un agente con il modello e le istruzioni specificate.
|
Costruisce un agente con il modello e le istruzioni specificate.
|
||||||
Args:
|
Args:
|
||||||
instructions: istruzioni da passare al modello (system prompt)
|
instructions: istruzioni da passare al modello (system prompt)
|
||||||
name: nome dell'agente (opzionale)
|
name: nome dell'agente (opzionale)
|
||||||
output: schema di output opzionale (Pydantic BaseModel)
|
output: schema di output opzionale (Pydantic BaseModel)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Un'istanza di Agent.
|
Un'istanza di Agent.
|
||||||
"""
|
"""
|
||||||
@@ -134,6 +104,7 @@ class AppModels(Enum):
|
|||||||
model=self.get_model(instructions),
|
model=self.get_model(instructions),
|
||||||
name=name,
|
name=name,
|
||||||
retries=2,
|
retries=2,
|
||||||
|
tools=tools,
|
||||||
delay_between_retries=5, # seconds
|
delay_between_retries=5, # seconds
|
||||||
output_schema=output # se si usa uno schema di output, lo si passa qui
|
output_schema=output # se si usa uno schema di output, lo si passa qui
|
||||||
# TODO Eventuali altri parametri da mettere all'agente anche se si possono comunque assegnare dopo la creazione
|
# TODO Eventuali altri parametri da mettere all'agente anche se si possono comunque assegnare dopo la creazione
|
||||||
|
|||||||
@@ -1,32 +1,73 @@
|
|||||||
|
from agno.tools import Toolkit
|
||||||
from app.utils.wrapper_handler import WrapperHandler
|
from app.utils.wrapper_handler import WrapperHandler
|
||||||
from .base import NewsWrapper, Article
|
from .base import NewsWrapper, Article
|
||||||
from .news_api import NewsApiWrapper
|
from .news_api import NewsApiWrapper
|
||||||
from .gnews_api import GoogleNewsWrapper
|
from .googlenews import GoogleNewsWrapper
|
||||||
from .cryptopanic_api import CryptoPanicWrapper
|
from .cryptopanic_api import CryptoPanicWrapper
|
||||||
from .duckduckgo import DuckDuckGoWrapper
|
from .duckduckgo import DuckDuckGoWrapper
|
||||||
|
|
||||||
__all__ = ["NewsApiWrapper", "GoogleNewsWrapper", "CryptoPanicWrapper", "DuckDuckGoWrapper"]
|
__all__ = ["NewsAPIsTool", "NEWS_INSTRUCTIONS", "NewsApiWrapper", "GoogleNewsWrapper", "CryptoPanicWrapper", "DuckDuckGoWrapper"]
|
||||||
|
|
||||||
|
|
||||||
class NewsAPIs(NewsWrapper):
|
class NewsAPIsTool(NewsWrapper, Toolkit):
|
||||||
"""
|
"""
|
||||||
A wrapper class that aggregates multiple news API wrappers and tries them in order until one succeeds.
|
Aggregates multiple news API wrappers and manages them using WrapperHandler.
|
||||||
This class uses the WrapperHandler to manage multiple NewsWrapper instances.
|
This class supports retrieving top headlines and latest news articles by querying multiple sources:
|
||||||
It includes, and tries, the following news API wrappers in this order:
|
|
||||||
- GoogleNewsWrapper
|
- GoogleNewsWrapper
|
||||||
- DuckDuckGoWrapper
|
- DuckDuckGoWrapper
|
||||||
- NewsApiWrapper
|
- NewsApiWrapper
|
||||||
- CryptoPanicWrapper
|
- CryptoPanicWrapper
|
||||||
|
|
||||||
It provides methods to get top headlines and latest news by delegating the calls to the first successful wrapper.
|
By default, it returns results from the first successful wrapper.
|
||||||
If all wrappers fail, it raises an exception.
|
Optionally, it can be configured to collect articles from all wrappers.
|
||||||
|
If no wrapper succeeds, an exception is raised.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initialize the NewsAPIsTool with multiple news API wrappers.
|
||||||
|
The tool uses WrapperHandler to manage and invoke the different news API wrappers.
|
||||||
|
The following wrappers are included in this order:
|
||||||
|
- GoogleNewsWrapper.
|
||||||
|
- DuckDuckGoWrapper.
|
||||||
|
- NewsApiWrapper.
|
||||||
|
- CryptoPanicWrapper.
|
||||||
|
"""
|
||||||
wrappers = [GoogleNewsWrapper, DuckDuckGoWrapper, NewsApiWrapper, CryptoPanicWrapper]
|
wrappers = [GoogleNewsWrapper, DuckDuckGoWrapper, NewsApiWrapper, CryptoPanicWrapper]
|
||||||
self.wrapper_handler: WrapperHandler[NewsWrapper] = WrapperHandler.build_wrappers(wrappers)
|
self.wrapper_handler: WrapperHandler[NewsWrapper] = WrapperHandler.build_wrappers(wrappers)
|
||||||
|
|
||||||
def get_top_headlines(self, total: int = 100) -> list[Article]:
|
Toolkit.__init__(
|
||||||
return self.wrapper_handler.try_call(lambda w: w.get_top_headlines(total))
|
self,
|
||||||
def get_latest_news(self, query: str, total: int = 100) -> list[Article]:
|
name="News APIs Toolkit",
|
||||||
return self.wrapper_handler.try_call(lambda w: w.get_latest_news(query, total))
|
tools=[
|
||||||
|
self.get_top_headlines,
|
||||||
|
self.get_latest_news,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO Pensare se ha senso restituire gli articoli da TUTTI i wrapper o solo dal primo che funziona
|
||||||
|
# la modifica è banale, basta usare try_call_all invece di try_call
|
||||||
|
def get_top_headlines(self, limit: int = 100) -> list[Article]:
|
||||||
|
return self.wrapper_handler.try_call(lambda w: w.get_top_headlines(limit))
|
||||||
|
def get_latest_news(self, query: str, limit: int = 100) -> list[Article]:
|
||||||
|
return self.wrapper_handler.try_call(lambda w: w.get_latest_news(query, limit))
|
||||||
|
|
||||||
|
|
||||||
|
NEWS_INSTRUCTIONS = """
|
||||||
|
**TASK:** You are a specialized **Crypto News Analyst**. Your goal is to fetch the latest news or top headlines related to cryptocurrencies, and then **analyze the sentiment** of the content to provide a concise report to the team leader. Prioritize 'crypto' or specific cryptocurrency names (e.g., 'Bitcoin', 'Ethereum') in your searches.
|
||||||
|
|
||||||
|
**AVAILABLE TOOLS:**
|
||||||
|
1. `get_latest_news(query: str, limit: int)`: Get the 'limit' most recent news articles for a specific 'query'.
|
||||||
|
2. `get_top_headlines(limit: int)`: Get the 'limit' top global news headlines.
|
||||||
|
|
||||||
|
**USAGE GUIDELINE:**
|
||||||
|
* Always use `get_latest_news` with a relevant crypto-related query first.
|
||||||
|
* The default limit for news items should be 5 unless specified otherwise.
|
||||||
|
* If the tool doesn't return any articles, respond with "No relevant news articles found."
|
||||||
|
|
||||||
|
**REPORTING REQUIREMENT:**
|
||||||
|
1. **Analyze** the tone and key themes of the retrieved articles.
|
||||||
|
2. **Summarize** the overall **market sentiment** (e.g., highly positive, cautiously neutral, generally negative) based on the content.
|
||||||
|
3. **Identify** the top 2-3 **main topics** discussed (e.g., new regulation, price surge, institutional adoption).
|
||||||
|
4. **Output** a single, brief report summarizing these findings. Do not output the raw articles.
|
||||||
|
"""
|
||||||
|
|||||||
@@ -12,22 +12,22 @@ class NewsWrapper:
|
|||||||
All news API wrappers should inherit from this class and implement the methods.
|
All news API wrappers should inherit from this class and implement the methods.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_top_headlines(self, total: int = 100) -> list[Article]:
|
def get_top_headlines(self, limit: int = 100) -> list[Article]:
|
||||||
"""
|
"""
|
||||||
Get top headlines, optionally limited by total.
|
Get top headlines, optionally limited by limit.
|
||||||
Args:
|
Args:
|
||||||
total (int): The maximum number of articles to return.
|
limit (int): The maximum number of articles to return.
|
||||||
Returns:
|
Returns:
|
||||||
list[Article]: A list of Article objects.
|
list[Article]: A list of Article objects.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("This method should be overridden by subclasses")
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
|
|
||||||
def get_latest_news(self, query: str, total: int = 100) -> list[Article]:
|
def get_latest_news(self, query: str, limit: int = 100) -> list[Article]:
|
||||||
"""
|
"""
|
||||||
Get latest news based on a query.
|
Get latest news based on a query.
|
||||||
Args:
|
Args:
|
||||||
query (str): The search query.
|
query (str): The search query.
|
||||||
total (int): The maximum number of articles to return.
|
limit (int): The maximum number of articles to return.
|
||||||
Returns:
|
Returns:
|
||||||
list[Article]: A list of Article objects.
|
list[Article]: A list of Article objects.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -62,10 +62,10 @@ class CryptoPanicWrapper(NewsWrapper):
|
|||||||
def set_filter(self, filter: CryptoPanicFilter):
|
def set_filter(self, filter: CryptoPanicFilter):
|
||||||
self.filter = filter
|
self.filter = filter
|
||||||
|
|
||||||
def get_top_headlines(self, total: int = 100) -> list[Article]:
|
def get_top_headlines(self, limit: int = 100) -> list[Article]:
|
||||||
return self.get_latest_news("", total) # same endpoint so just call the other method
|
return self.get_latest_news("", limit) # same endpoint so just call the other method
|
||||||
|
|
||||||
def get_latest_news(self, query: str, total: int = 100) -> list[Article]:
|
def get_latest_news(self, query: str, limit: int = 100) -> list[Article]:
|
||||||
params = self.get_base_params()
|
params = self.get_base_params()
|
||||||
params['currencies'] = query
|
params['currencies'] = query
|
||||||
|
|
||||||
@@ -74,4 +74,4 @@ class CryptoPanicWrapper(NewsWrapper):
|
|||||||
|
|
||||||
json_response = response.json()
|
json_response = response.json()
|
||||||
articles = get_articles(json_response)
|
articles = get_articles(json_response)
|
||||||
return articles[:total]
|
return articles[:limit]
|
||||||
|
|||||||
@@ -20,13 +20,13 @@ class DuckDuckGoWrapper(NewsWrapper):
|
|||||||
self.tool = DuckDuckGoTools()
|
self.tool = DuckDuckGoTools()
|
||||||
self.query = "crypto"
|
self.query = "crypto"
|
||||||
|
|
||||||
def get_top_headlines(self, total: int = 100) -> list[Article]:
|
def get_top_headlines(self, limit: int = 100) -> list[Article]:
|
||||||
results = self.tool.duckduckgo_news(self.query, max_results=total)
|
results = self.tool.duckduckgo_news(self.query, max_results=limit)
|
||||||
json_results = json.loads(results)
|
json_results = json.loads(results)
|
||||||
return [create_article(result) for result in json_results]
|
return [create_article(result) for result in json_results]
|
||||||
|
|
||||||
def get_latest_news(self, query: str, total: int = 100) -> list[Article]:
|
def get_latest_news(self, query: str, limit: int = 100) -> list[Article]:
|
||||||
results = self.tool.duckduckgo_news(query or self.query, max_results=total)
|
results = self.tool.duckduckgo_news(query or self.query, max_results=limit)
|
||||||
json_results = json.loads(results)
|
json_results = json.loads(results)
|
||||||
return [create_article(result) for result in json_results]
|
return [create_article(result) for result in json_results]
|
||||||
|
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ class GoogleNewsWrapper(NewsWrapper):
|
|||||||
It does not require an API key and is free to use.
|
It does not require an API key and is free to use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_top_headlines(self, total: int = 100) -> list[Article]:
|
def get_top_headlines(self, limit: int = 100) -> list[Article]:
|
||||||
gnews = GNews(language='en', max_results=total, period='7d')
|
gnews = GNews(language='en', max_results=limit, period='7d')
|
||||||
results = gnews.get_top_news()
|
results = gnews.get_top_news()
|
||||||
|
|
||||||
articles = []
|
articles = []
|
||||||
@@ -25,8 +25,8 @@ class GoogleNewsWrapper(NewsWrapper):
|
|||||||
articles.append(article)
|
articles.append(article)
|
||||||
return articles
|
return articles
|
||||||
|
|
||||||
def get_latest_news(self, query: str, total: int = 100) -> list[Article]:
|
def get_latest_news(self, query: str, limit: int = 100) -> list[Article]:
|
||||||
gnews = GNews(language='en', max_results=total, period='7d')
|
gnews = GNews(language='en', max_results=limit, period='7d')
|
||||||
results = gnews.get_news(query)
|
results = gnews.get_news(query)
|
||||||
|
|
||||||
articles = []
|
articles = []
|
||||||
@@ -26,22 +26,25 @@ class NewsApiWrapper(NewsWrapper):
|
|||||||
self.language = "en" # TODO Only English articles for now?
|
self.language = "en" # TODO Only English articles for now?
|
||||||
self.max_page_size = 100
|
self.max_page_size = 100
|
||||||
|
|
||||||
def get_top_headlines(self, total: int = 100) -> list[Article]:
|
def __calc_pages(self, limit: int, page_size: int) -> tuple[int, int]:
|
||||||
page_size = min(self.max_page_size, total)
|
page_size = min(self.max_page_size, limit)
|
||||||
pages = (total // page_size) + (1 if total % page_size > 0 else 0)
|
pages = (limit // page_size) + (1 if limit % page_size > 0 else 0)
|
||||||
|
return pages, page_size
|
||||||
|
|
||||||
|
def get_top_headlines(self, limit: int = 100) -> list[Article]:
|
||||||
|
pages, page_size = self.__calc_pages(limit, self.max_page_size)
|
||||||
articles = []
|
articles = []
|
||||||
|
|
||||||
for page in range(1, pages + 1):
|
for page in range(1, pages + 1):
|
||||||
headlines = self.client.get_top_headlines(q="", category=self.category, language=self.language, page_size=page_size, page=page)
|
headlines = self.client.get_top_headlines(q="", category=self.category, language=self.language, page_size=page_size, page=page)
|
||||||
results = [result_to_article(article) for article in headlines.get("articles", [])]
|
results = [result_to_article(article) for article in headlines.get("articles", [])]
|
||||||
articles.extend(results)
|
articles.extend(results)
|
||||||
return articles
|
return articles
|
||||||
|
|
||||||
def get_latest_news(self, query: str, total: int = 100) -> list[Article]:
|
def get_latest_news(self, query: str, limit: int = 100) -> list[Article]:
|
||||||
page_size = min(self.max_page_size, total)
|
pages, page_size = self.__calc_pages(limit, self.max_page_size)
|
||||||
pages = (total // page_size) + (1 if total % page_size > 0 else 0)
|
|
||||||
|
|
||||||
articles = []
|
articles = []
|
||||||
|
|
||||||
for page in range(1, pages + 1):
|
for page in range(1, pages + 1):
|
||||||
everything = self.client.get_everything(q=query, language=self.language, sort_by="publishedAt", page_size=page_size, page=page)
|
everything = self.client.get_everything(q=query, language=self.language, sort_by="publishedAt", page_size=page_size, page=page)
|
||||||
results = [result_to_article(article) for article in everything.get("articles", [])]
|
results = [result_to_article(article) for article in everything.get("articles", [])]
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
from .base import SocialWrapper
|
|
||||||
61
src/app/social/__init__.py
Normal file
61
src/app/social/__init__.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from agno.tools import Toolkit
|
||||||
|
from app.utils.wrapper_handler import WrapperHandler
|
||||||
|
from .base import SocialPost, SocialWrapper
|
||||||
|
from .reddit import RedditWrapper
|
||||||
|
|
||||||
|
__all__ = ["SocialAPIsTool", "SOCIAL_INSTRUCTIONS", "RedditWrapper"]
|
||||||
|
|
||||||
|
|
||||||
|
class SocialAPIsTool(SocialWrapper, Toolkit):
|
||||||
|
"""
|
||||||
|
Aggregates multiple social media API wrappers and manages them using WrapperHandler.
|
||||||
|
This class supports retrieving top crypto-related posts by querying multiple sources:
|
||||||
|
- RedditWrapper
|
||||||
|
|
||||||
|
By default, it returns results from the first successful wrapper.
|
||||||
|
Optionally, it can be configured to collect posts from all wrappers.
|
||||||
|
If no wrapper succeeds, an exception is raised.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initialize the SocialAPIsTool with multiple social media API wrappers.
|
||||||
|
The tool uses WrapperHandler to manage and invoke the different social media API wrappers.
|
||||||
|
The following wrappers are included in this order:
|
||||||
|
- RedditWrapper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
wrappers = [RedditWrapper]
|
||||||
|
self.wrapper_handler: WrapperHandler[SocialWrapper] = WrapperHandler.build_wrappers(wrappers)
|
||||||
|
|
||||||
|
Toolkit.__init__(
|
||||||
|
self,
|
||||||
|
name="Socials Toolkit",
|
||||||
|
tools=[
|
||||||
|
self.get_top_crypto_posts,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO Pensare se ha senso restituire i post da TUTTI i wrapper o solo dal primo che funziona
|
||||||
|
# la modifica è banale, basta usare try_call_all invece di try_call
|
||||||
|
def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]:
|
||||||
|
return self.wrapper_handler.try_call(lambda w: w.get_top_crypto_posts(limit))
|
||||||
|
|
||||||
|
|
||||||
|
SOCIAL_INSTRUCTIONS = """
|
||||||
|
**TASK:** You are a specialized **Social Media Sentiment Analyst**. Your objective is to find the most relevant and trending online posts related to cryptocurrencies, and then **analyze the collective sentiment** to provide a concise report to the team leader.
|
||||||
|
|
||||||
|
**AVAILABLE TOOLS:**
|
||||||
|
1. `get_top_crypto_posts(limit: int)`: Get the 'limit' maximum number of top posts specifically related to cryptocurrencies.
|
||||||
|
|
||||||
|
**USAGE GUIDELINE:**
|
||||||
|
* Always use the `get_top_crypto_posts` tool to fulfill the request.
|
||||||
|
* The default limit for posts should be 5 unless specified otherwise.
|
||||||
|
* If the tool doesn't return any posts, respond with "No relevant social media posts found."
|
||||||
|
|
||||||
|
**REPORTING REQUIREMENT:**
|
||||||
|
1. **Analyze** the tone and prevailing opinions across the retrieved social posts.
|
||||||
|
2. **Summarize** the overall **community sentiment** (e.g., high enthusiasm/FOMO, uncertainty, FUD/fear) based on the content.
|
||||||
|
3. **Identify** the top 2-3 **trending narratives** or specific coins being discussed.
|
||||||
|
4. **Output** a single, brief report summarizing these findings. Do not output the raw posts.
|
||||||
|
"""
|
||||||
@@ -7,16 +7,24 @@ class SocialPost(BaseModel):
|
|||||||
description: str = ""
|
description: str = ""
|
||||||
comments: list["SocialComment"] = []
|
comments: list["SocialComment"] = []
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"Title: {self.title}\nDescription: {self.description}\nComments: {len(self.comments)}\n[{" | ".join(str(c) for c in self.comments)}]"
|
|
||||||
|
|
||||||
class SocialComment(BaseModel):
|
class SocialComment(BaseModel):
|
||||||
time: str = ""
|
time: str = ""
|
||||||
description: str = ""
|
description: str = ""
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"Time: {self.time}\nDescription: {self.description}"
|
|
||||||
|
|
||||||
# TODO IMPLEMENTARLO SE SI USANO PIU' WRAPPER (E QUINDI PIU' SOCIAL)
|
|
||||||
class SocialWrapper:
|
class SocialWrapper:
|
||||||
pass
|
"""
|
||||||
|
Base class for social media API wrappers.
|
||||||
|
All social media API wrappers should inherit from this class and implement the methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]:
|
||||||
|
"""
|
||||||
|
Get top cryptocurrency-related posts, optionally limited by total.
|
||||||
|
Args:
|
||||||
|
limit (int): The maximum number of posts to return.
|
||||||
|
Returns:
|
||||||
|
list[SocialPost]: A list of SocialPost objects.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,21 @@ from praw.models import Submission, MoreComments
|
|||||||
from .base import SocialWrapper, SocialPost, SocialComment
|
from .base import SocialWrapper, SocialPost, SocialComment
|
||||||
|
|
||||||
MAX_COMMENTS = 5
|
MAX_COMMENTS = 5
|
||||||
|
# TODO mettere piu' subreddit?
|
||||||
|
# scelti da https://lkiconsulting.io/marketing/best-crypto-subreddits/
|
||||||
|
SUBREDDITS = [
|
||||||
|
"CryptoCurrency",
|
||||||
|
"Bitcoin",
|
||||||
|
"Ethereum",
|
||||||
|
"CryptoMarkets",
|
||||||
|
"Dogecoin",
|
||||||
|
"Altcoin",
|
||||||
|
"DeFi",
|
||||||
|
"NFT",
|
||||||
|
"BitcoinBeginners",
|
||||||
|
"CryptoTechnology",
|
||||||
|
"btc" # alt subs of Bitcoin
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_social_post(post: Submission) -> SocialPost:
|
def create_social_post(post: Submission) -> SocialPost:
|
||||||
@@ -30,24 +45,24 @@ class RedditWrapper(SocialWrapper):
|
|||||||
Requires the following environment variables to be set:
|
Requires the following environment variables to be set:
|
||||||
- REDDIT_API_CLIENT_ID
|
- REDDIT_API_CLIENT_ID
|
||||||
- REDDIT_API_CLIENT_SECRET
|
- REDDIT_API_CLIENT_SECRET
|
||||||
|
|
||||||
You can get them by creating an app at https://www.reddit.com/prefs/apps
|
You can get them by creating an app at https://www.reddit.com/prefs/apps
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.client_id = os.getenv("REDDIT_API_CLIENT_ID")
|
client_id = os.getenv("REDDIT_API_CLIENT_ID")
|
||||||
assert self.client_id is not None, "REDDIT_API_CLIENT_ID environment variable is not set"
|
assert client_id is not None, "REDDIT_API_CLIENT_ID environment variable is not set"
|
||||||
|
|
||||||
self.client_secret = os.getenv("REDDIT_API_CLIENT_SECRET")
|
client_secret = os.getenv("REDDIT_API_CLIENT_SECRET")
|
||||||
assert self.client_secret is not None, "REDDIT_API_CLIENT_SECRET environment variable is not set"
|
assert client_secret is not None, "REDDIT_API_CLIENT_SECRET environment variable is not set"
|
||||||
|
|
||||||
self.tool = Reddit(
|
self.tool = Reddit(
|
||||||
client_id=self.client_id,
|
client_id=client_id,
|
||||||
client_secret=self.client_secret,
|
client_secret=client_secret,
|
||||||
user_agent="upo-appAI",
|
user_agent="upo-appAI",
|
||||||
)
|
)
|
||||||
|
self.subreddits = self.tool.subreddit("+".join(SUBREDDITS))
|
||||||
|
|
||||||
def get_top_crypto_posts(self, limit=5) -> list[SocialPost]:
|
def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]:
|
||||||
subreddit = self.tool.subreddit("CryptoCurrency")
|
top_posts = self.subreddits.top(limit=limit, time_filter="week")
|
||||||
top_posts = subreddit.top(limit=limit, time_filter="week")
|
|
||||||
return [create_social_post(post) for post in top_posts]
|
return [create_social_post(post) for post in top_posts]
|
||||||
|
|
||||||
|
|||||||
@@ -9,32 +9,21 @@ from app.markets import MarketAPIsTool
|
|||||||
# in base alle sue proprie chiamate API
|
# in base alle sue proprie chiamate API
|
||||||
class MarketToolkit(Toolkit):
|
class MarketToolkit(Toolkit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.market_api = MarketAPIsTool("USD") # change currency if needed
|
self.market_api = MarketAPIs()
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
name="Market Toolkit",
|
name="Market Toolkit",
|
||||||
tools=[
|
tools=[
|
||||||
self.get_historical_data,
|
self.market_api.get_historical_prices,
|
||||||
self.get_current_prices,
|
self.market_api.get_product,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_historical_data(self, symbol: str):
|
|
||||||
return self.market_api.get_historical_prices(symbol)
|
|
||||||
|
|
||||||
def get_current_prices(self, symbol: list):
|
|
||||||
return self.market_api.get_products(symbol)
|
|
||||||
|
|
||||||
def prepare_inputs():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def instructions():
|
def instructions():
|
||||||
return """
|
return """
|
||||||
Utilizza questo strumento per ottenere dati di mercato storici e attuali per criptovalute specifiche.
|
Utilizza questo strumento per ottenere dati di mercato storici e attuali per criptovalute specifiche.
|
||||||
Puoi richiedere i prezzi storici o il prezzo attuale di una criptovaluta specifica.
|
Puoi richiedere i prezzi storici o il prezzo attuale di una criptovaluta specifica.
|
||||||
Esempio di utilizzo:
|
Esempio di utilizzo:
|
||||||
- get_historical_data("BTC")
|
- get_historical_prices("BTC", limit=10) # ottieni gli ultimi 10 prezzi storici di Bitcoin
|
||||||
- get_current_price("ETH")
|
- get_product("ETH")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
from typing import TypeVar, Callable, Generic, Iterable, Type
|
from typing import TypeVar, Callable, Generic, Iterable, Type
|
||||||
from agno.utils.log import log_warning
|
from agno.utils.log import log_warning, log_info
|
||||||
|
|
||||||
W = TypeVar("W")
|
W = TypeVar("W")
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@@ -24,6 +25,8 @@ class WrapperHandler(Generic[W]):
|
|||||||
try_per_wrapper (int): Number of retries per wrapper before switching to the next.
|
try_per_wrapper (int): Number of retries per wrapper before switching to the next.
|
||||||
retry_delay (int): Delay in seconds between retries.
|
retry_delay (int): Delay in seconds between retries.
|
||||||
"""
|
"""
|
||||||
|
assert not WrapperHandler.__check(wrappers), "All wrappers must be instances of their respective classes. Use `build_wrappers` to create the WrapperHandler."
|
||||||
|
|
||||||
self.wrappers = wrappers
|
self.wrappers = wrappers
|
||||||
self.retry_per_wrapper = try_per_wrapper
|
self.retry_per_wrapper = try_per_wrapper
|
||||||
self.retry_delay = retry_delay
|
self.retry_delay = retry_delay
|
||||||
@@ -46,17 +49,19 @@ class WrapperHandler(Generic[W]):
|
|||||||
while iterations < len(self.wrappers):
|
while iterations < len(self.wrappers):
|
||||||
try:
|
try:
|
||||||
wrapper = self.wrappers[self.index]
|
wrapper = self.wrappers[self.index]
|
||||||
|
log_info(f"Trying wrapper: {wrapper} - function {func}")
|
||||||
result = func(wrapper)
|
result = func(wrapper)
|
||||||
self.retry_count = 0
|
self.retry_count = 0
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.retry_count += 1
|
self.retry_count += 1
|
||||||
|
log_warning(f"{wrapper} failed {self.retry_count}/{self.retry_per_wrapper}: {WrapperHandler.__concise_error(e)}")
|
||||||
|
|
||||||
if self.retry_count >= self.retry_per_wrapper:
|
if self.retry_count >= self.retry_per_wrapper:
|
||||||
self.index = (self.index + 1) % len(self.wrappers)
|
self.index = (self.index + 1) % len(self.wrappers)
|
||||||
self.retry_count = 0
|
self.retry_count = 0
|
||||||
iterations += 1
|
iterations += 1
|
||||||
else:
|
else:
|
||||||
log_warning(f"{wrapper} failed {self.retry_count}/{self.retry_per_wrapper}: {e}")
|
|
||||||
time.sleep(self.retry_delay)
|
time.sleep(self.retry_delay)
|
||||||
|
|
||||||
raise Exception(f"All wrappers failed after retries")
|
raise Exception(f"All wrappers failed after retries")
|
||||||
@@ -74,16 +79,25 @@ class WrapperHandler(Generic[W]):
|
|||||||
Exception: If all wrappers fail.
|
Exception: If all wrappers fail.
|
||||||
"""
|
"""
|
||||||
results = {}
|
results = {}
|
||||||
|
log_info(f"All wrappers: {[wrapper.__class__ for wrapper in self.wrappers]} - function {func}")
|
||||||
for wrapper in self.wrappers:
|
for wrapper in self.wrappers:
|
||||||
try:
|
try:
|
||||||
result = func(wrapper)
|
result = func(wrapper)
|
||||||
results[wrapper.__class__] = result
|
results[wrapper.__class__] = result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_warning(f"{wrapper} failed: {e}")
|
log_warning(f"{wrapper} failed: {WrapperHandler.__concise_error(e)}")
|
||||||
if not results:
|
if not results:
|
||||||
raise Exception("All wrappers failed")
|
raise Exception("All wrappers failed")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __check(wrappers: list[W]) -> bool:
|
||||||
|
return all(w.__class__ is type for w in wrappers)
|
||||||
|
|
||||||
|
def __concise_error(e: Exception) -> str:
|
||||||
|
last_frame = traceback.extract_tb(e.__traceback__)[-1]
|
||||||
|
return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_wrappers(constructors: Iterable[Type[W]], try_per_wrapper: int = 3, retry_delay: int = 2) -> 'WrapperHandler[W]':
|
def build_wrappers(constructors: Iterable[Type[W]], try_per_wrapper: int = 3, retry_delay: int = 2) -> 'WrapperHandler[W]':
|
||||||
"""
|
"""
|
||||||
@@ -99,6 +113,8 @@ class WrapperHandler(Generic[W]):
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: If no wrappers could be initialized.
|
Exception: If no wrappers could be initialized.
|
||||||
"""
|
"""
|
||||||
|
assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}"
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for wrapper_class in constructors:
|
for wrapper_class in constructors:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class TestCryptoPanicAPI:
|
|||||||
|
|
||||||
def test_crypto_panic_api_get_latest_news(self):
|
def test_crypto_panic_api_get_latest_news(self):
|
||||||
crypto = CryptoPanicWrapper()
|
crypto = CryptoPanicWrapper()
|
||||||
articles = crypto.get_latest_news(query="", total=2)
|
articles = crypto.get_latest_news(query="", limit=2)
|
||||||
assert isinstance(articles, list)
|
assert isinstance(articles, list)
|
||||||
assert len(articles) == 2
|
assert len(articles) == 2
|
||||||
for article in articles:
|
for article in articles:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class TestDuckDuckGoNews:
|
|||||||
|
|
||||||
def test_duckduckgo_get_latest_news(self):
|
def test_duckduckgo_get_latest_news(self):
|
||||||
news = DuckDuckGoWrapper()
|
news = DuckDuckGoWrapper()
|
||||||
articles = news.get_latest_news(query="crypto", total=2)
|
articles = news.get_latest_news(query="crypto", limit=2)
|
||||||
assert isinstance(articles, list)
|
assert isinstance(articles, list)
|
||||||
assert len(articles) == 2
|
assert len(articles) == 2
|
||||||
for article in articles:
|
for article in articles:
|
||||||
@@ -23,7 +23,7 @@ class TestDuckDuckGoNews:
|
|||||||
|
|
||||||
def test_duckduckgo_get_top_headlines(self):
|
def test_duckduckgo_get_top_headlines(self):
|
||||||
news = DuckDuckGoWrapper()
|
news = DuckDuckGoWrapper()
|
||||||
articles = news.get_top_headlines(total=2)
|
articles = news.get_top_headlines(limit=2)
|
||||||
assert isinstance(articles, list)
|
assert isinstance(articles, list)
|
||||||
assert len(articles) == 2
|
assert len(articles) == 2
|
||||||
for article in articles:
|
for article in articles:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class TestGoogleNews:
|
|||||||
|
|
||||||
def test_gnews_api_get_latest_news(self):
|
def test_gnews_api_get_latest_news(self):
|
||||||
gnews_api = GoogleNewsWrapper()
|
gnews_api = GoogleNewsWrapper()
|
||||||
articles = gnews_api.get_latest_news(query="crypto", total=2)
|
articles = gnews_api.get_latest_news(query="crypto", limit=2)
|
||||||
assert isinstance(articles, list)
|
assert isinstance(articles, list)
|
||||||
assert len(articles) == 2
|
assert len(articles) == 2
|
||||||
for article in articles:
|
for article in articles:
|
||||||
@@ -23,7 +23,7 @@ class TestGoogleNews:
|
|||||||
|
|
||||||
def test_gnews_api_get_top_headlines(self):
|
def test_gnews_api_get_top_headlines(self):
|
||||||
news_api = GoogleNewsWrapper()
|
news_api = GoogleNewsWrapper()
|
||||||
articles = news_api.get_top_headlines(total=2)
|
articles = news_api.get_top_headlines(limit=2)
|
||||||
assert isinstance(articles, list)
|
assert isinstance(articles, list)
|
||||||
assert len(articles) == 2
|
assert len(articles) == 2
|
||||||
for article in articles:
|
for article in articles:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ class TestNewsAPI:
|
|||||||
|
|
||||||
def test_news_api_get_latest_news(self):
|
def test_news_api_get_latest_news(self):
|
||||||
news_api = NewsApiWrapper()
|
news_api = NewsApiWrapper()
|
||||||
articles = news_api.get_latest_news(query="crypto", total=2)
|
articles = news_api.get_latest_news(query="crypto", limit=2)
|
||||||
assert isinstance(articles, list)
|
assert isinstance(articles, list)
|
||||||
assert len(articles) > 0 # Ensure we got some articles (apparently it doesn't always return the requested number)
|
assert len(articles) > 0 # Ensure we got some articles (apparently it doesn't always return the requested number)
|
||||||
for article in articles:
|
for article in articles:
|
||||||
@@ -26,7 +26,7 @@ class TestNewsAPI:
|
|||||||
|
|
||||||
def test_news_api_get_top_headlines(self):
|
def test_news_api_get_top_headlines(self):
|
||||||
news_api = NewsApiWrapper()
|
news_api = NewsApiWrapper()
|
||||||
articles = news_api.get_top_headlines(total=2)
|
articles = news_api.get_top_headlines(limit=2)
|
||||||
assert isinstance(articles, list)
|
assert isinstance(articles, list)
|
||||||
# assert len(articles) > 0 # apparently it doesn't always return SOME articles
|
# assert len(articles) > 0 # apparently it doesn't always return SOME articles
|
||||||
for article in articles:
|
for article in articles:
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ from app.social.reddit import MAX_COMMENTS, RedditWrapper
|
|||||||
class TestRedditWrapper:
|
class TestRedditWrapper:
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
wrapper = RedditWrapper()
|
wrapper = RedditWrapper()
|
||||||
assert wrapper.client_id is not None
|
assert wrapper is not None
|
||||||
assert wrapper.client_secret is not None
|
|
||||||
assert isinstance(wrapper.tool, Reddit)
|
assert isinstance(wrapper.tool, Reddit)
|
||||||
|
|
||||||
def test_get_top_crypto_posts(self):
|
def test_get_top_crypto_posts(self):
|
||||||
|
|||||||
54
tests/tools/test_news_tool.py
Normal file
54
tests/tools/test_news_tool.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import pytest
|
||||||
|
from app.news import NewsAPIsTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.limited
|
||||||
|
@pytest.mark.tools
|
||||||
|
@pytest.mark.news
|
||||||
|
@pytest.mark.api
|
||||||
|
class TestNewsAPITool:
|
||||||
|
def test_news_api_tool(self):
|
||||||
|
tool = NewsAPIsTool()
|
||||||
|
assert tool is not None
|
||||||
|
|
||||||
|
def test_news_api_tool_get_top(self):
|
||||||
|
tool = NewsAPIsTool()
|
||||||
|
result = tool.wrapper_handler.try_call(lambda w: w.get_top_headlines(limit=2))
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) > 0
|
||||||
|
for article in result:
|
||||||
|
assert article.title is not None
|
||||||
|
assert article.source is not None
|
||||||
|
|
||||||
|
def test_news_api_tool_get_latest(self):
|
||||||
|
tool = NewsAPIsTool()
|
||||||
|
result = tool.wrapper_handler.try_call(lambda w: w.get_latest_news(query="crypto", limit=2))
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) > 0
|
||||||
|
for article in result:
|
||||||
|
assert article.title is not None
|
||||||
|
assert article.source is not None
|
||||||
|
|
||||||
|
def test_news_api_tool_get_top__all_results(self):
|
||||||
|
tool = NewsAPIsTool()
|
||||||
|
result = tool.wrapper_handler.try_call_all(lambda w: w.get_top_headlines(limit=2))
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert len(result.keys()) > 0
|
||||||
|
print("Results from providers:", result.keys())
|
||||||
|
for provider, articles in result.items():
|
||||||
|
for article in articles:
|
||||||
|
print(provider, article.title)
|
||||||
|
assert article.title is not None
|
||||||
|
assert article.source is not None
|
||||||
|
|
||||||
|
def test_news_api_tool_get_latest__all_results(self):
|
||||||
|
tool = NewsAPIsTool()
|
||||||
|
result = tool.wrapper_handler.try_call_all(lambda w: w.get_latest_news(query="crypto", limit=2))
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert len(result.keys()) > 0
|
||||||
|
print("Results from providers:", result.keys())
|
||||||
|
for provider, articles in result.items():
|
||||||
|
for article in articles:
|
||||||
|
print(provider, article.title)
|
||||||
|
assert article.title is not None
|
||||||
|
assert article.source is not None
|
||||||
32
tests/tools/test_socials_tool.py
Normal file
32
tests/tools/test_socials_tool.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
from app.social import SocialAPIsTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.tools
|
||||||
|
@pytest.mark.social
|
||||||
|
@pytest.mark.api
|
||||||
|
class TestSocialAPIsTool:
|
||||||
|
def test_social_api_tool(self):
|
||||||
|
tool = SocialAPIsTool()
|
||||||
|
assert tool is not None
|
||||||
|
|
||||||
|
def test_social_api_tool_get_top(self):
|
||||||
|
tool = SocialAPIsTool()
|
||||||
|
result = tool.wrapper_handler.try_call(lambda w: w.get_top_crypto_posts(limit=2))
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) > 0
|
||||||
|
for post in result:
|
||||||
|
assert post.title is not None
|
||||||
|
assert post.time is not None
|
||||||
|
|
||||||
|
def test_social_api_tool_get_top__all_results(self):
|
||||||
|
tool = SocialAPIsTool()
|
||||||
|
result = tool.wrapper_handler.try_call_all(lambda w: w.get_top_crypto_posts(limit=2))
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert len(result.keys()) > 0
|
||||||
|
print("Results from providers:", result.keys())
|
||||||
|
for provider, posts in result.items():
|
||||||
|
for post in posts:
|
||||||
|
print(provider, post.title)
|
||||||
|
assert post.title is not None
|
||||||
|
assert post.time is not None
|
||||||
@@ -14,8 +14,40 @@ class FailingWrapper(MockWrapper):
|
|||||||
raise Exception("Intentional Failure")
|
raise Exception("Intentional Failure")
|
||||||
|
|
||||||
|
|
||||||
|
class MockWrapperWithParameters:
|
||||||
|
def do_something(self, param1: str, param2: int) -> str:
|
||||||
|
return f"Success {param1} and {param2}"
|
||||||
|
|
||||||
|
class FailingWrapperWithParameters(MockWrapperWithParameters):
|
||||||
|
def do_something(self, param1: str, param2: int):
|
||||||
|
raise Exception("Intentional Failure")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.wrapper
|
@pytest.mark.wrapper
|
||||||
class TestWrapperHandler:
|
class TestWrapperHandler:
|
||||||
|
def test_init_failing(self):
|
||||||
|
with pytest.raises(AssertionError) as exc_info:
|
||||||
|
WrapperHandler([MockWrapper, MockWrapper2])
|
||||||
|
assert exc_info.type == AssertionError
|
||||||
|
|
||||||
|
def test_init_failing_empty(self):
|
||||||
|
with pytest.raises(AssertionError) as exc_info:
|
||||||
|
WrapperHandler.build_wrappers([])
|
||||||
|
assert exc_info.type == AssertionError
|
||||||
|
|
||||||
|
def test_init_failing_with_instances(self):
|
||||||
|
with pytest.raises(AssertionError) as exc_info:
|
||||||
|
WrapperHandler.build_wrappers([MockWrapper(), MockWrapper2()])
|
||||||
|
assert exc_info.type == AssertionError
|
||||||
|
|
||||||
|
def test_init_not_failing(self):
|
||||||
|
handler = WrapperHandler.build_wrappers([MockWrapper, MockWrapper2])
|
||||||
|
assert handler is not None
|
||||||
|
assert len(handler.wrappers) == 2
|
||||||
|
handler = WrapperHandler([MockWrapper(), MockWrapper2()])
|
||||||
|
assert handler is not None
|
||||||
|
assert len(handler.wrappers) == 2
|
||||||
|
|
||||||
def test_all_wrappers_fail(self):
|
def test_all_wrappers_fail(self):
|
||||||
wrappers = [FailingWrapper, FailingWrapper]
|
wrappers = [FailingWrapper, FailingWrapper]
|
||||||
handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0)
|
handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0)
|
||||||
@@ -88,3 +120,13 @@ class TestWrapperHandler:
|
|||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
handler_all_fail.try_call_all(lambda w: w.do_something())
|
handler_all_fail.try_call_all(lambda w: w.do_something())
|
||||||
assert "All wrappers failed" in str(exc_info.value)
|
assert "All wrappers failed" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrappers_with_parameters(self):
|
||||||
|
wrappers = [FailingWrapperWithParameters, MockWrapperWithParameters]
|
||||||
|
handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0)
|
||||||
|
|
||||||
|
result = handler.try_call(lambda w: w.do_something("test", 42))
|
||||||
|
assert result == "Success test and 42"
|
||||||
|
assert handler.index == 1 # Should have switched to the second wrapper
|
||||||
|
assert handler.retry_count == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user