diff --git a/Dockerfile b/Dockerfile index 16868ac..f4d7e97 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,23 +1,26 @@ -# Vogliamo usare una versione di linux leggera con già uv installato -# Infatti scegliamo l'immagine ufficiale di uv che ha già tutto configurato -FROM ghcr.io/astral-sh/uv:python3.12-alpine +# Utilizziamo Debian slim invece di Alpine per migliore compatibilità +FROM debian:bookworm-slim +RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/* -# Dopo aver definito la workdir mi trovo già in essa -WORKDIR /app +# Installiamo uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/root/.local/bin:$PATH" -# Settiamo variabili d'ambiente per usare python del sistema invece che venv -ENV UV_PROJECT_ENVIRONMENT=/usr/local +# Configuriamo UV per usare copy mode ed evitare problemi di linking ENV UV_LINK_MODE=copy -# Copiamo prima i file di configurazione delle dipendenze e installiamo le dipendenze +# Impostiamo la directory di lavoro +WORKDIR /app + +# Copiamo i file del progetto COPY pyproject.toml ./ COPY uv.lock ./ -RUN uv sync --frozen --no-cache +COPY LICENSE ./ +COPY src/ ./src/ -# Copiamo i file sorgente dopo aver installato le dipendenze per sfruttare la cache di Docker -COPY LICENSE . -COPY src ./src +# Creiamo l'ambiente virtuale con tutto già presente +RUN uv sync +ENV PYTHONPATH="/app/src" -# Comando di default all'avvio dell'applicazione -CMD ["echo", "Benvenuto in UPO AppAI!"] -CMD ["uv", "run", "src/app.py"] +# Comando di avvio dell'applicazione +CMD ["uv", "run", "src/app"] diff --git a/README.md b/README.md index a545c92..aae9a60 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,13 @@ L'obiettivo è quello di creare un sistema di consulenza finanziaria basato su L # **Indice** - [Installazione](#installazione) - - [1. Variabili d'Ambiente](#1-variabili-dambiente) - - [2. Ollama](#2-ollama) - - [3. Docker](#3-docker) - - [4. UV (solo per sviluppo locale)](#4-uv-solo-per-sviluppo-locale) + - [1. Variabili d'Ambiente](#1-variabili-dambiente) + - [2. Ollama](#2-ollama) + - [3. Docker](#3-docker) + - [4. UV (solo per sviluppo locale)](#4-uv-solo-per-sviluppo-locale) - [Applicazione](#applicazione) - - [Ultimo Aggiornamento](#ultimo-aggiornamento) - - [Tests](#tests) + - [Struttura del codice del Progetto](#struttura-del-codice-del-progetto) + - [Tests](#tests) # **Installazione** @@ -31,9 +31,10 @@ L'installazione di questo progetto richiede 3 passaggi totali (+1 se si vuole sv ### **1. Variabili d'Ambiente** -Copia il file `.env.example` in `.env` e modificalo con le tue API keys: +Copia il file `.env.example` in `.env` e successivamente modificalo con le tue API keys: ```sh cp .env.example .env +nano .env # esempio di modifica del file ``` Le API Keys devono essere inserite nelle variabili opportune dopo l'uguale e ***senza*** spazi. Esse si possono ottenere tramite i loro providers (alcune sono gratuite, altre a pagamento).\ @@ -58,11 +59,6 @@ I modelli usati dall'applicazione sono visibili in [src/app/models.py](src/app/m Se si vuole solamente avviare il progetto, si consiglia di utilizzare [Docker](https://www.docker.com), dato che sono stati creati i files [Dockerfile](Dockerfile) e [docker-compose.yaml](docker-compose.yaml) per creare il container con tutti i file necessari e già in esecuzione. ```sh -# Configura le variabili d'ambiente -cp .env.example .env -nano .env # Modifica il file - -# Avvia il container docker compose up --build -d ``` @@ -80,16 +76,17 @@ powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | ie curl -LsSf https://astral.sh/uv/install.sh | sh ``` -UV installerà python e creerà automaticamente l'ambiente virtuale con le dipendenze corrette (nota che questo passaggio è opzionale, dato che uv, ogni volta che si esegue un comando, controlla se l'ambiente è attivo e se le dipendenze sono installate): +Dopodiché bisogna creare un ambiente virtuale per lo sviluppo locale e impostare PYTHONPATH. Questo passaggio è necessario per far sì che Python riesca a trovare tutti i moduli del progetto e ad installare tutte le dipendenze. Fortunatamente uv semplifica molto questo processo: ```sh -uv sync --frozen --no-cache +uv venv +uv pip install -e . ``` -A questo punto si può far partire il progetto tramite il comando: +A questo punto si può già modificare il codice e, quando necessario, far partire il progetto tramite il comando: ```sh -uv run python src/app.py +uv run python src/app ``` # **Applicazione** @@ -102,6 +99,20 @@ Usando la libreria ``gradio`` è stata creata un'interfaccia web semplice per in - **Social Agent**: Analizza i sentimenti sui social media riguardo alle criptovalute. - **Predictor Agent**: Utilizza i dati raccolti dagli altri agenti per fare previsioni. +## Struttura del codice del Progetto + +``` +src +└── app + ├── __main__.py + ├── agents <-- Agenti, modelli, prompts e simili + ├── base <-- Classi base per le API + ├── markets <-- Market data provider (Es. Binance) + ├── news <-- News data provider (Es. NewsAPI) + ├── social <-- Social data provider (Es. Reddit) + └── utils <-- Codice di utilità generale +``` + ## Tests Per eseguire i test, assicurati di aver configurato correttamente le variabili d'ambiente nel file `.env` come descritto sopra. Poi esegui il comando: diff --git a/demos/example.py b/demos/example.py index c1fa08c..35acf59 100644 --- a/demos/example.py +++ b/demos/example.py @@ -14,7 +14,7 @@ try: instructions="Use tables to display data.", markdown=True, ) - result = reasoning_agent.run("Scrivi una poesia su un gatto. Sii breve.") + result = reasoning_agent.run("Scrivi una poesia su un gatto. Sii breve.") # type: ignore print(result.content) except Exception as e: print(f"Si è verificato un errore: {e}") diff --git a/demos/market_providers_api_demo.py b/demos/market_providers_api_demo.py index fc05c26..caba571 100644 --- a/demos/market_providers_api_demo.py +++ b/demos/market_providers_api_demo.py @@ -32,7 +32,7 @@ from app.markets import ( CryptoCompareWrapper, BinanceWrapper, YFinanceWrapper, - BaseWrapper + MarketWrapper ) # Carica variabili d'ambiente @@ -40,21 +40,21 @@ load_dotenv() class DemoFormatter: """Classe per formattare l'output del demo in modo strutturato.""" - + @staticmethod def print_header(title: str, char: str = "=", width: int = 80): """Stampa un'intestazione formattata.""" print(f"\n{char * width}") print(f"{title:^{width}}") print(f"{char * width}") - + @staticmethod def print_subheader(title: str, char: str = "-", width: int = 60): """Stampa una sotto-intestazione formattata.""" print(f"\n{char * width}") print(f" {title}") print(f"{char * width}") - + @staticmethod def print_request_info(provider_name: str, method: str, timestamp: datetime, status: str, error: Optional[str] = None): @@ -66,83 +66,83 @@ class DemoFormatter: if error: print(f"❌ Error: {error}") print() - + @staticmethod def print_product_table(products: List[Any], title: str = "Products"): """Stampa una tabella di prodotti.""" if not products: print(f"📋 {title}: Nessun prodotto trovato") return - + print(f"📋 {title} ({len(products)} items):") print(f"{'Symbol':<15} {'ID':<20} {'Price':<12} {'Quote':<10} {'Status':<10}") print("-" * 67) - + for product in products[:10]: # Mostra solo i primi 10 symbol = getattr(product, 'symbol', 'N/A') product_id = getattr(product, 'id', 'N/A') price = getattr(product, 'price', 0.0) quote = getattr(product, 'quote_currency', 'N/A') status = getattr(product, 'status', 'N/A') - + # Tronca l'ID se troppo lungo if len(product_id) > 18: product_id = product_id[:15] + "..." - + price_str = f"${price:.2f}" if price > 0 else "N/A" - + print(f"{symbol:<15} {product_id:<20} {price_str:<12} {quote:<10} {status:<10}") - + if len(products) > 10: print(f"... e altri {len(products) - 10} prodotti") print() - + @staticmethod def print_prices_table(prices: List[Any], title: str = "Historical Prices"): """Stampa una tabella di prezzi storici.""" if not prices: print(f"💰 {title}: Nessun prezzo trovato") return - + print(f"💰 {title} ({len(prices)} entries):") print(f"{'Time':<12} {'Open':<12} {'High':<12} {'Low':<12} {'Close':<12} {'Volume':<15}") print("-" * 75) - + for price in prices[:5]: # Mostra solo i primi 5 time_str = getattr(price, 'time', 'N/A') # Il time è già una stringa, non serve strftime if len(time_str) > 10: time_str = time_str[:10] # Tronca se troppo lungo - + open_price = f"${getattr(price, 'open', 0):.2f}" high_price = f"${getattr(price, 'high', 0):.2f}" low_price = f"${getattr(price, 'low', 0):.2f}" close_price = f"${getattr(price, 'close', 0):.2f}" volume = f"{getattr(price, 'volume', 0):,.0f}" - + print(f"{time_str:<12} {open_price:<12} {high_price:<12} {low_price:<12} {close_price:<12} {volume:<15}") - + if len(prices) > 5: print(f"... e altri {len(prices) - 5} prezzi") print() class ProviderTester: """Classe per testare i provider di market data.""" - + def __init__(self): self.formatter = DemoFormatter() self.test_symbols = ["BTC", "ETH", "ADA"] - - def test_provider(self, wrapper: BaseWrapper, provider_name: str) -> Dict[str, Any]: + + def test_provider(self, wrapper: MarketWrapper, provider_name: str) -> Dict[str, Any]: """Testa un provider specifico con tutti i metodi disponibili.""" - results = { + results: Dict[str, Any] = { "provider_name": provider_name, "tests": {}, "overall_status": "SUCCESS" } - + self.formatter.print_subheader(f"🔍 Testing {provider_name}") - + # Test get_product for symbol in self.test_symbols: timestamp = datetime.now() @@ -153,13 +153,13 @@ class ProviderTester: ) if product: print(f"📦 Product: {product.symbol} (ID: {product.id})") - print(f" Price: ${product.price:.2f}, Quote: {product.quote_currency}") + print(f" Price: ${product.price:.2f}, Quote: {product.currency}") print(f" Volume 24h: {product.volume_24h:,.2f}") else: print(f"📦 Product: Nessun prodotto trovato per {symbol}") - + results["tests"][f"get_product_{symbol}"] = "SUCCESS" - + except Exception as e: error_msg = str(e) self.formatter.print_request_info( @@ -167,7 +167,7 @@ class ProviderTester: ) results["tests"][f"get_product_{symbol}"] = f"ERROR: {error_msg}" results["overall_status"] = "PARTIAL" - + # Test get_products timestamp = datetime.now() try: @@ -177,7 +177,7 @@ class ProviderTester: ) self.formatter.print_product_table(products, f"{provider_name} Products") results["tests"]["get_products"] = "SUCCESS" - + except Exception as e: error_msg = str(e) self.formatter.print_request_info( @@ -185,7 +185,7 @@ class ProviderTester: ) results["tests"]["get_products"] = f"ERROR: {error_msg}" results["overall_status"] = "PARTIAL" - + # Test get_historical_prices timestamp = datetime.now() try: @@ -195,7 +195,7 @@ class ProviderTester: ) self.formatter.print_prices_table(prices, f"{provider_name} BTC Historical Prices") results["tests"]["get_historical_prices"] = "SUCCESS" - + except Exception as e: error_msg = str(e) self.formatter.print_request_info( @@ -203,7 +203,7 @@ class ProviderTester: ) results["tests"]["get_historical_prices"] = f"ERROR: {error_msg}" results["overall_status"] = "PARTIAL" - + return results def check_environment_variables() -> Dict[str, bool]: @@ -217,11 +217,11 @@ def check_environment_variables() -> Dict[str, bool]: } return env_vars -def initialize_providers() -> Dict[str, BaseWrapper]: +def initialize_providers() -> Dict[str, MarketWrapper]: """Inizializza tutti i provider disponibili.""" - providers = {} + providers: Dict[str, MarketWrapper] = {} env_vars = check_environment_variables() - + # CryptoCompareWrapper if env_vars["CRYPTOCOMPARE_API_KEY"]: try: @@ -231,7 +231,7 @@ def initialize_providers() -> Dict[str, BaseWrapper]: print(f"❌ Errore nell'inizializzazione di CryptoCompareWrapper: {e}") else: print("⚠️ CryptoCompareWrapper saltato: CRYPTOCOMPARE_API_KEY non trovata") - + # CoinBaseWrapper if env_vars["COINBASE_API_KEY"] and env_vars["COINBASE_API_SECRET"]: try: @@ -241,14 +241,14 @@ def initialize_providers() -> Dict[str, BaseWrapper]: print(f"❌ Errore nell'inizializzazione di CoinBaseWrapper: {e}") else: print("⚠️ CoinBaseWrapper saltato: credenziali Coinbase non complete") - + # BinanceWrapper try: providers["Binance"] = BinanceWrapper() print("✅ BinanceWrapper inizializzato con successo") except Exception as e: print(f"❌ Errore nell'inizializzazione di BinanceWrapper: {e}") - + # YFinanceWrapper (sempre disponibile - dati azionari e crypto gratuiti) try: providers["YFinance"] = YFinanceWrapper() @@ -261,22 +261,22 @@ def print_summary(results: List[Dict[str, Any]]): """Stampa un riassunto finale dei risultati.""" formatter = DemoFormatter() formatter.print_header("📊 RIASSUNTO FINALE", "=", 80) - + total_providers = len(results) successful_providers = sum(1 for r in results if r["overall_status"] == "SUCCESS") partial_providers = sum(1 for r in results if r["overall_status"] == "PARTIAL") - + print(f"🔢 Provider testati: {total_providers}") print(f"✅ Provider completamente funzionanti: {successful_providers}") print(f"⚠️ Provider parzialmente funzionanti: {partial_providers}") print(f"❌ Provider non funzionanti: {total_providers - successful_providers - partial_providers}") - + print("\n📋 Dettaglio per provider:") for result in results: provider_name = result["provider_name"] status = result["overall_status"] status_icon = "✅" if status == "SUCCESS" else "⚠️" if status == "PARTIAL" else "❌" - + print(f"\n{status_icon} {provider_name}:") for test_name, test_result in result["tests"].items(): test_icon = "✅" if test_result == "SUCCESS" else "❌" @@ -285,39 +285,39 @@ def print_summary(results: List[Dict[str, Any]]): def main(): """Funzione principale del demo.""" formatter = DemoFormatter() - + # Intestazione principale formatter.print_header("🚀 DEMO COMPLETO MARKET DATA PROVIDERS", "=", 80) - + print(f"🕒 Avvio demo: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") print("📝 Questo demo testa tutti i wrapper BaseWrapper disponibili") print("🔍 Ogni test include timestamp, stato della richiesta e dati formattati") - + # Verifica variabili d'ambiente formatter.print_subheader("🔐 Verifica Configurazione") env_vars = check_environment_variables() - + print("Variabili d'ambiente:") for var_name, is_present in env_vars.items(): status = "✅ Presente" if is_present else "❌ Mancante" print(f" {var_name}: {status}") - + # Inizializza provider formatter.print_subheader("🏗️ Inizializzazione Provider") providers = initialize_providers() - + if not providers: print("❌ Nessun provider disponibile. Verifica la configurazione.") return - + print(f"\n🎯 Provider disponibili per il test: {list(providers.keys())}") - + # Testa ogni provider formatter.print_header("🧪 ESECUZIONE TEST PROVIDER", "=", 80) - + tester = ProviderTester() - all_results = [] - + all_results: List[Dict[str, Any]] = [] + for provider_name, wrapper in providers.items(): try: result = tester.test_provider(wrapper, provider_name) @@ -331,22 +331,22 @@ def main(): "overall_status": "CRITICAL_ERROR", "error": str(e) }) - + # Stampa riassunto finale print_summary(all_results) - + # Informazioni aggiuntive formatter.print_header("ℹ️ INFORMAZIONI AGGIUNTIVE", "=", 80) print("📚 Documentazione:") print(" - BaseWrapper: src/app/markets/base.py") print(" - Test completi: tests/agents/test_market.py") print(" - Configurazione: .env") - + print("\n🔧 Per abilitare tutti i provider:") print(" 1. Configura le credenziali nel file .env") print(" 2. Segui la documentazione di ogni provider") print(" 3. Riavvia il demo") - + print(f"\n🏁 Demo completato: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") if __name__ == "__main__": diff --git a/demos/news_api.py b/demos/news_api.py index 26dab24..1497a15 100644 --- a/demos/news_api.py +++ b/demos/news_api.py @@ -9,6 +9,8 @@ from app.news import NewsApiWrapper def main(): api = NewsApiWrapper() + articles = api.get_latest_news(query="bitcoin", limit=5) + assert len(articles) > 0 print("ok") if __name__ == "__main__": diff --git a/src/app.py b/src/app/__main__.py similarity index 67% rename from src/app.py rename to src/app/__main__.py index 65c22cc..578ef35 100644 --- a/src/app.py +++ b/src/app/__main__.py @@ -1,40 +1,39 @@ import gradio as gr -from agno.utils.log import log_info from dotenv import load_dotenv +from agno.utils.log import log_info #type: ignore +from app.utils import ChatManager +from app.agents import Pipeline -from app.chat_manager import ChatManager -######################################## -# MAIN APP & GRADIO CHAT INTERFACE -######################################## if __name__ == "__main__": - # Carica variabili d’ambiente (.env) + # Inizializzazioni load_dotenv() - - # Inizializza ChatManager + pipeline = Pipeline() chat = ChatManager() ######################################## # Funzioni Gradio ######################################## - def respond(message, history): - response = chat.send_message(message) + def respond(message: str, history: list[dict[str, str]]) -> tuple[list[dict[str, str]], list[dict[str, str]], str]: + chat.send_message(message) + response = pipeline.interact(message) + chat.receive_message(response) history.append({"role": "user", "content": message}) history.append({"role": "assistant", "content": response}) return history, history, "" - def save_current_chat(): + def save_current_chat() -> str: chat.save_chat("chat.json") return "💾 Chat salvata in chat.json" - def load_previous_chat(): + def load_previous_chat() -> tuple[list[dict[str, str]], list[dict[str, str]]]: chat.load_chat("chat.json") - history = [] + history: list[dict[str, str]] = [] for m in chat.get_history(): history.append({"role": m["role"], "content": m["content"]}) return history, history - def reset_chat(): + def reset_chat() -> tuple[list[dict[str, str]], list[dict[str, str]]]: chat.reset_chat() return [], [] @@ -47,18 +46,18 @@ if __name__ == "__main__": # Dropdown provider e stile with gr.Row(): provider = gr.Dropdown( - choices=chat.list_providers(), + choices=pipeline.list_providers(), type="index", label="Modello da usare" ) - provider.change(fn=chat.choose_provider, inputs=provider, outputs=None) + provider.change(fn=pipeline.choose_predictor, inputs=provider, outputs=None) style = gr.Dropdown( - choices=chat.list_styles(), + choices=pipeline.list_styles(), type="index", label="Stile di investimento" ) - style.change(fn=chat.choose_style, inputs=style, outputs=None) + style.change(fn=pipeline.choose_style, inputs=style, outputs=None) chatbot = gr.Chatbot(label="Conversazione", height=500, type="messages") msg = gr.Textbox(label="Scrivi la tua richiesta", placeholder="Es: Quali sono le crypto interessanti oggi?") @@ -68,16 +67,13 @@ if __name__ == "__main__": save_btn = gr.Button("💾 Salva Chat") load_btn = gr.Button("📂 Carica Chat") - # Invio messaggio + # Eventi e interazioni msg.submit(respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot, msg]) - # Reset clear_btn.click(reset_chat, inputs=None, outputs=[chatbot, chatbot]) - # Salvataggio save_btn.click(save_current_chat, inputs=None, outputs=None) - # Caricamento load_btn.click(load_previous_chat, inputs=None, outputs=[chatbot, chatbot]) - server, port = ("0.0.0.0", 8000) + server, port = ("0.0.0.0", 8000) # 0.0.0.0 per accesso esterno (Docker) server_log = "localhost" if server == "0.0.0.0" else server log_info(f"Starting UPO AppAI Chat on http://{server_log}:{port}") # noqa demo.launch(server_name=server, server_port=port, quiet=True) diff --git a/src/app/agents/__init__.py b/src/app/agents/__init__.py new file mode 100644 index 0000000..a9ec99e --- /dev/null +++ b/src/app/agents/__init__.py @@ -0,0 +1,6 @@ +from app.agents.models import AppModels +from app.agents.predictor import PredictorInput, PredictorOutput, PredictorStyle, PREDICTOR_INSTRUCTIONS +from app.agents.team import create_team_with +from app.agents.pipeline import Pipeline + +__all__ = ["AppModels", "PredictorInput", "PredictorOutput", "PredictorStyle", "PREDICTOR_INSTRUCTIONS", "create_team_with", "Pipeline"] diff --git a/src/app/models.py b/src/app/agents/models.py similarity index 82% rename from src/app/models.py rename to src/app/agents/models.py index 4cc591d..79d4a26 100644 --- a/src/app/models.py +++ b/src/app/agents/models.py @@ -1,12 +1,12 @@ import os -import requests +import ollama from enum import Enum from agno.agent import Agent from agno.models.base import Model from agno.models.google import Gemini from agno.models.ollama import Ollama -from agno.utils.log import log_warning from agno.tools import Toolkit +from agno.utils.log import log_warning #type: ignore from pydantic import BaseModel @@ -30,19 +30,15 @@ class AppModels(Enum): Controlla quali provider di modelli LLM locali sono disponibili. Ritorna una lista di provider disponibili. """ - ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434") - result = requests.get(f"{ollama_host}/api/tags") - if result.status_code != 200: - log_warning(f"Ollama is not running or not reachable {result}") + try: + models_list = ollama.list() + availables = [model['model'] for model in models_list['models']] + app_models = [model for model in AppModels if model.name.startswith("OLLAMA")] + return [model for model in app_models if model.value in availables] + except Exception as e: + log_warning(f"Ollama is not running or not reachable: {e}") return [] - availables = [] - result = result.text - for model in [model for model in AppModels if model.name.startswith("OLLAMA")]: - if model.value in result: - availables.append(model) - return availables - @staticmethod def availables_online() -> list['AppModels']: """ @@ -90,13 +86,14 @@ class AppModels(Enum): raise ValueError(f"Modello non supportato: {self}") - def get_agent(self, instructions: str, name: str = "", output: BaseModel | None = None, tools: list[Toolkit] = []) -> Agent: + def get_agent(self, instructions: str, name: str = "", output_schema: type[BaseModel] | None = None, tools: list[Toolkit] | None = None) -> Agent: """ Costruisce un agente con il modello e le istruzioni specificate. Args: instructions: istruzioni da passare al modello (system prompt) name: nome dell'agente (opzionale) output: schema di output opzionale (Pydantic BaseModel) + tools: lista opzionale di strumenti (tools) da fornire all'agente Returns: Un'istanza di Agent. """ @@ -106,6 +103,5 @@ class AppModels(Enum): retries=2, tools=tools, delay_between_retries=5, # seconds - output_schema=output # se si usa uno schema di output, lo si passa qui - # TODO Eventuali altri parametri da mettere all'agente anche se si possono comunque assegnare dopo la creazione + output_schema=output_schema ) diff --git a/src/app/pipeline.py b/src/app/agents/pipeline.py similarity index 50% rename from src/app/pipeline.py rename to src/app/agents/pipeline.py index a7ae9d4..a7d1001 100644 --- a/src/app/pipeline.py +++ b/src/app/agents/pipeline.py @@ -1,11 +1,8 @@ from agno.run.agent import RunOutput -from agno.team import Team - -from app.news import NewsAPIsTool, NEWS_INSTRUCTIONS -from app.social import SocialAPIsTool, SOCIAL_INSTRUCTIONS -from app.markets import MarketAPIsTool, MARKET_INSTRUCTIONS -from app.models import AppModels -from app.predictor import PredictorStyle, PredictorInput, PredictorOutput, PREDICTOR_INSTRUCTIONS +from app.agents.models import AppModels +from app.agents.team import create_team_with +from app.agents.predictor import PREDICTOR_INSTRUCTIONS, PredictorInput, PredictorOutput, PredictorStyle +from app.base.markets import ProductInfo class Pipeline: @@ -14,56 +11,27 @@ class Pipeline: Il Team è orchestrato da qwen3:latest (Ollama), mentre il Predictor è dinamico e scelto dall'utente tramite i dropdown dell'interfaccia grafica. """ + def __init__(self): - # Inizializza gli agenti - self.market_agent = AppModels.OLLAMA_QWEN.get_agent( - instructions=MARKET_INSTRUCTIONS, - name="MarketAgent", - tools=[MarketAPIsTool()] - ) - self.news_agent = AppModels.OLLAMA_QWEN.get_agent( - instructions=NEWS_INSTRUCTIONS, - name="NewsAgent", - tools=[NewsAPIsTool()] - ) - self.social_agent = AppModels.OLLAMA_QWEN.get_agent( - instructions=SOCIAL_INSTRUCTIONS, - name="SocialAgent", - tools=[SocialAPIsTool()] - ) - - # === Modello di orchestrazione del Team === - team_model = AppModels.OLLAMA_QWEN.get_model( - # TODO: migliorare le istruzioni del team - "Agisci come coordinatore: smista le richieste tra MarketAgent, NewsAgent e SocialAgent." - ) - - # === Team === - self.team = Team( - name="CryptoAnalysisTeam", - members=[self.market_agent, self.news_agent, self.social_agent], - model=team_model - ) - - # === Predictor === self.available_models = AppModels.availables() self.all_styles = list(PredictorStyle) - # Scelte di default - self.chosen_model = self.available_models[0] if self.available_models else None - self.style = self.all_styles[0] if self.all_styles else None - - self._init_predictor() # Inizializza il predictor con il modello di default + self.style = self.all_styles[0] + self.team = create_team_with(AppModels.OLLAMA_QWEN_1B) + self.choose_predictor(0) # Modello di default # ====================== # Dropdown handlers # ====================== - def choose_provider(self, index: int): + def choose_predictor(self, index: int): """ Sceglie il modello LLM da usare per il Predictor. """ - self.chosen_model = self.available_models[index] - self._init_predictor() + model = self.available_models[index] + self.predictor = model.get_agent( + PREDICTOR_INSTRUCTIONS, + output_schema=PredictorOutput, + ) def choose_style(self, index: int): """ @@ -74,17 +42,6 @@ class Pipeline: # ====================== # Helpers # ====================== - def _init_predictor(self): - """ - Inizializza (o reinizializza) il Predictor in base al modello scelto. - """ - if not self.chosen_model: - return - self.predictor = self.chosen_model.get_agent( - PREDICTOR_INSTRUCTIONS, - output=PredictorOutput, # type: ignore - ) - def list_providers(self) -> list[str]: """ Restituisce la lista dei nomi dei modelli disponibili. @@ -107,23 +64,21 @@ class Pipeline: 3. Invoca Predictor 4. Restituisce la strategia finale """ - if not self.predictor or not self.style: - return "⚠️ Devi prima selezionare un modello e una strategia validi dagli appositi menu." - # Step 1: raccolta output dai membri del Team - team_outputs = self.team.run(query) + team_outputs = self.team.run(query) # type: ignore # Step 2: aggregazione output strutturati - all_products = [] - sentiments = [] + all_products: list[ProductInfo] = [] + sentiments: list[str] = [] for agent_output in team_outputs.member_responses: - if isinstance(agent_output, RunOutput): - if "products" in agent_output.metadata: + if isinstance(agent_output, RunOutput) and agent_output.metadata is not None: + keys = agent_output.metadata.keys() + if "products" in keys: all_products.extend(agent_output.metadata["products"]) - if "sentiment_news" in agent_output.metadata: + if "sentiment_news" in keys: sentiments.append(agent_output.metadata["sentiment_news"]) - if "sentiment_social" in agent_output.metadata: + if "sentiment_social" in keys: sentiments.append(agent_output.metadata["sentiment_social"]) aggregated_sentiment = "\n".join(sentiments) @@ -135,7 +90,9 @@ class Pipeline: sentiment=aggregated_sentiment ) - result = self.predictor.run(predictor_input) + result = self.predictor.run(predictor_input) # type: ignore + if not isinstance(result.content, PredictorOutput): + return "❌ Errore: il modello non ha restituito un output valido." prediction: PredictorOutput = result.content # Step 4: restituzione strategia finale diff --git a/src/app/predictor.py b/src/app/agents/predictor.py similarity index 98% rename from src/app/predictor.py rename to src/app/agents/predictor.py index 38780de..69a92af 100644 --- a/src/app/predictor.py +++ b/src/app/agents/predictor.py @@ -1,6 +1,6 @@ from enum import Enum from pydantic import BaseModel, Field -from app.markets.base import ProductInfo +from app.base.markets import ProductInfo class PredictorStyle(Enum): @@ -21,6 +21,7 @@ class PredictorOutput(BaseModel): strategy: str = Field(..., description="Concise operational strategy in Italian") portfolio: list[ItemPortfolio] = Field(..., description="List of portfolio items with allocations") + PREDICTOR_INSTRUCTIONS = """ You are an **Allocation Algorithm (Crypto-Algo)** specialized in analyzing market data and sentiment to generate an investment strategy and a target portfolio. diff --git a/src/app/agents/team.py b/src/app/agents/team.py new file mode 100644 index 0000000..27b9cae --- /dev/null +++ b/src/app/agents/team.py @@ -0,0 +1,109 @@ +from agno.team import Team +from app.agents import AppModels +from app.markets import MarketAPIsTool +from app.news import NewsAPIsTool +from app.social import SocialAPIsTool + + +def create_team_with(models: AppModels, coordinator: AppModels | None = None) -> Team: + market_agent = models.get_agent( + instructions=MARKET_INSTRUCTIONS, + name="MarketAgent", + tools=[MarketAPIsTool()] + ) + news_agent = models.get_agent( + instructions=NEWS_INSTRUCTIONS, + name="NewsAgent", + tools=[NewsAPIsTool()] + ) + social_agent = models.get_agent( + instructions=SOCIAL_INSTRUCTIONS, + name="SocialAgent", + tools=[SocialAPIsTool()] + ) + + coordinator = coordinator or models + return Team( + model=coordinator.get_model(COORDINATOR_INSTRUCTIONS), + name="CryptoAnalysisTeam", + members=[market_agent, news_agent, social_agent], + ) + +COORDINATOR_INSTRUCTIONS = """ +You are the expert coordinator of a financial analysis team specializing in cryptocurrencies. + +Your team consists of three agents: +- **MarketAgent**: Provides quantitative market data, price analysis, and technical indicators. +- **NewsAgent**: Scans and analyzes the latest news, articles, and official announcements. +- **SocialAgent**: Gauges public sentiment, trends, and discussions on social media. + +Your primary objective is to answer the user's query by orchestrating the work of your team members. + +Your workflow is as follows: +1. **Deconstruct the user's query** to identify the required information. +2. **Delegate specific tasks** to the most appropriate agent(s) to gather the necessary data and initial analysis. +3. **Analyze the information** returned by the agents. +4. If the initial data is insufficient or the query is complex, **iteratively re-engage the agents** with follow-up questions to build a comprehensive picture. +5. **Synthesize all the gathered information** into a final, coherent, and complete analysis that fills all the required output fields. +""" + +MARKET_INSTRUCTIONS = """ +**TASK:** You are a specialized **Crypto Price Data Retrieval Agent**. Your primary goal is to fetch the most recent and/or historical price data for requested cryptocurrency assets (e.g., 'BTC', 'ETH', 'SOL'). You must provide the data in a clear and structured format. + +**AVAILABLE TOOLS:** +1. `get_products(asset_ids: list[str])`: Get **current** product/price info for a list of assets. **(PREFERITA: usa questa per i prezzi live)** +2. `get_historical_prices(asset_id: str, limit: int)`: Get historical price data for one asset. Default limit is 100. **(PREFERITA: usa questa per i dati storici)** +3. `get_products_aggregated(asset_ids: list[str])`: Get **aggregated current** product/price info for a list of assets. **(USA SOLO SE richiesto 'aggregato' o se `get_products` fallisce)** +4. `get_historical_prices_aggregated(asset_id: str, limit: int)`: Get **aggregated historical** price data for one asset. **(USA SOLO SE richiesto 'aggregato' o se `get_historical_prices` fallisce)** + +**USAGE GUIDELINE:** +* **Asset ID:** Always convert common names (e.g., 'Bitcoin', 'Ethereum') into their official ticker/ID (e.g., 'BTC', 'ETH'). +* **Cost Management (Cruciale per LLM locale):** Prefer `get_products` and `get_historical_prices` for standard requests to minimize costs. +* **Aggregated Data:** Use `get_products_aggregated` or `get_historical_prices_aggregated` only if the user specifically requests aggregated data or you value that having aggregated data is crucial for the analysis. +* **Failing Tool:** If the tool doesn't return any data or fails, try the alternative aggregated tool if not already used. + +**REPORTING REQUIREMENT:** +1. **Format:** Output the results in a clear, easy-to-read list or table. +2. **Live Price Request:** If an asset's *current price* is requested, report the **Asset ID**, **Latest Price**, and **Time/Date of the price**. +3. **Historical Price Request:** If *historical data* is requested, report the **Asset ID**, the **Limit** of points returned, and the **First** and **Last** entries from the list of historical prices (Date, Price). +4. **Output:** For all requests, output a single, concise summary of the findings; if requested, also include the raw data retrieved. +""" + +NEWS_INSTRUCTIONS = """ +**TASK:** You are a specialized **Crypto News Analyst**. Your goal is to fetch the latest news or top headlines related to cryptocurrencies, and then **analyze the sentiment** of the content to provide a concise report to the team leader. Prioritize 'crypto' or specific cryptocurrency names (e.g., 'Bitcoin', 'Ethereum') in your searches. + +**AVAILABLE TOOLS:** +1. `get_latest_news(query: str, limit: int)`: Get the 'limit' most recent news articles for a specific 'query'. +2. `get_top_headlines(limit: int)`: Get the 'limit' top global news headlines. +3. `get_latest_news_aggregated(query: str, limit: int)`: Get aggregated latest news articles for a specific 'query'. +4. `get_top_headlines_aggregated(limit: int)`: Get aggregated top global news headlines. + +**USAGE GUIDELINE:** +* Always use `get_latest_news` with a relevant crypto-related query first. +* The default limit for news items should be 5 unless specified otherwise. +* If the tool doesn't return any articles, respond with "No relevant news articles found." + +**REPORTING REQUIREMENT:** +1. **Analyze** the tone and key themes of the retrieved articles. +2. **Summarize** the overall **market sentiment** (e.g., highly positive, cautiously neutral, generally negative) based on the content. +3. **Identify** the top 2-3 **main topics** discussed (e.g., new regulation, price surge, institutional adoption). +4. **Output** a single, brief report summarizing these findings. Do not output the raw articles. +""" + +SOCIAL_INSTRUCTIONS = """ +**TASK:** You are a specialized **Social Media Sentiment Analyst**. Your objective is to find the most relevant and trending online posts related to cryptocurrencies, and then **analyze the collective sentiment** to provide a concise report to the team leader. + +**AVAILABLE TOOLS:** +1. `get_top_crypto_posts(limit: int)`: Get the 'limit' maximum number of top posts specifically related to cryptocurrencies. + +**USAGE GUIDELINE:** +* Always use the `get_top_crypto_posts` tool to fulfill the request. +* The default limit for posts should be 5 unless specified otherwise. +* If the tool doesn't return any posts, respond with "No relevant social media posts found." + +**REPORTING REQUIREMENT:** +1. **Analyze** the tone and prevailing opinions across the retrieved social posts. +2. **Summarize** the overall **community sentiment** (e.g., high enthusiasm/FOMO, uncertainty, FUD/fear) based on the content. +3. **Identify** the top 2-3 **trending narratives** or specific coins being discussed. +4. **Output** a single, brief report summarizing these findings. Do not output the raw posts. +""" diff --git a/src/app/__init__.py b/src/app/base/__init__.py similarity index 100% rename from src/app/__init__.py rename to src/app/base/__init__.py diff --git a/src/app/markets/base.py b/src/app/base/markets.py similarity index 61% rename from src/app/markets/base.py rename to src/app/base/markets.py index 1ef247b..cd00879 100644 --- a/src/app/markets/base.py +++ b/src/app/base/markets.py @@ -1,41 +1,6 @@ +from datetime import datetime from pydantic import BaseModel -class BaseWrapper: - """ - Base class for market API wrappers. - All market API wrappers should inherit from this class and implement the methods. - """ - - def get_product(self, asset_id: str) -> 'ProductInfo': - """ - Get product information for a specific asset ID. - Args: - asset_id (str): The asset ID to retrieve information for. - Returns: - ProductInfo: An object containing product information. - """ - raise NotImplementedError("This method should be overridden by subclasses") - - def get_products(self, asset_ids: list[str]) -> list['ProductInfo']: - """ - Get product information for multiple asset IDs. - Args: - asset_ids (list[str]): The list of asset IDs to retrieve information for. - Returns: - list[ProductInfo]: A list of objects containing product information. - """ - raise NotImplementedError("This method should be overridden by subclasses") - - def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list['Price']: - """ - Get historical price data for a specific asset ID. - Args: - asset_id (str): The asset ID to retrieve price data for. - limit (int): The maximum number of price data points to return. - Returns: - list[Price]: A list of Price objects. - """ - raise NotImplementedError("This method should be overridden by subclasses") class ProductInfo(BaseModel): """ @@ -46,7 +11,7 @@ class ProductInfo(BaseModel): symbol: str = "" price: float = 0.0 volume_24h: float = 0.0 - quote_currency: str = "" + currency: str = "" class Price(BaseModel): """ @@ -58,4 +23,61 @@ class Price(BaseModel): open: float = 0.0 close: float = 0.0 volume: float = 0.0 - timestamp_ms: int = 0 # Timestamp in milliseconds + timestamp: str = "" + """Timestamp con formato YYYY-MM-DD HH:MM""" + + def set_timestamp(self, timestamp_ms: int | None = None, timestamp_s: int | None = None) -> None: + """ + Imposta il timestamp a partire da millisecondi o secondi. + IL timestamp viene salvato come stringa formattata 'YYYY-MM-DD HH:MM'. + Args: + timestamp_ms: Timestamp in millisecondi. + timestamp_s: Timestamp in secondi. + Raises: + """ + if timestamp_ms is not None: + timestamp = timestamp_ms // 1000 + elif timestamp_s is not None: + timestamp = timestamp_s + else: + raise ValueError("Either timestamp_ms or timestamp_s must be provided") + assert timestamp > 0, "Invalid timestamp data received" + + self.timestamp = datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M') + +class MarketWrapper: + """ + Base class for market API wrappers. + All market API wrappers should inherit from this class and implement the methods. + """ + + def get_product(self, asset_id: str) -> ProductInfo: + """ + Get product information for a specific asset ID. + Args: + asset_id (str): The asset ID to retrieve information for. + Returns: + ProductInfo: An object containing product information. + """ + raise NotImplementedError("This method should be overridden by subclasses") + + def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: + """ + Get product information for multiple asset IDs. + Args: + asset_ids (list[str]): The list of asset IDs to retrieve information for. + Returns: + list[ProductInfo]: A list of objects containing product information. + """ + raise NotImplementedError("This method should be overridden by subclasses") + + def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]: + """ + Get historical price data for a specific asset ID. + Args: + asset_id (str): The asset ID to retrieve price data for. + limit (int): The maximum number of price data points to return. + Returns: + list[Price]: A list of Price objects. + """ + raise NotImplementedError("This method should be overridden by subclasses") diff --git a/src/app/news/base.py b/src/app/base/news.py similarity index 99% rename from src/app/news/base.py rename to src/app/base/news.py index 55a35ee..8a0d51e 100644 --- a/src/app/news/base.py +++ b/src/app/base/news.py @@ -1,5 +1,6 @@ from pydantic import BaseModel + class Article(BaseModel): source: str = "" time: str = "" diff --git a/src/app/social/base.py b/src/app/base/social.py similarity index 100% rename from src/app/social/base.py rename to src/app/base/social.py diff --git a/src/app/markets/__init__.py b/src/app/markets/__init__.py index ef73f68..bf2d344 100644 --- a/src/app/markets/__init__.py +++ b/src/app/markets/__init__.py @@ -1,16 +1,15 @@ from agno.tools import Toolkit -from app.utils.wrapper_handler import WrapperHandler -from app.utils.market_aggregation import aggregate_product_info, aggregate_history_prices -from .base import BaseWrapper, ProductInfo, Price -from .coinbase import CoinBaseWrapper -from .binance import BinanceWrapper -from .cryptocompare import CryptoCompareWrapper -from .yfinance import YFinanceWrapper +from app.base.markets import MarketWrapper, Price, ProductInfo +from app.markets.binance import BinanceWrapper +from app.markets.coinbase import CoinBaseWrapper +from app.markets.cryptocompare import CryptoCompareWrapper +from app.markets.yfinance import YFinanceWrapper +from app.utils import aggregate_history_prices, aggregate_product_info, WrapperHandler -__all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "MARKET_INSTRUCTIONS" ] +__all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "ProductInfo", "Price" ] -class MarketAPIsTool(BaseWrapper, Toolkit): +class MarketAPIsTool(MarketWrapper, Toolkit): """ Class that aggregates multiple market API wrappers and manages them using WrapperHandler. This class supports retrieving product information and historical prices. @@ -34,10 +33,10 @@ class MarketAPIsTool(BaseWrapper, Toolkit): currency (str): Valuta in cui restituire i prezzi. Default è "USD". """ kwargs = {"currency": currency or "USD"} - wrappers = [ BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper ] - self.wrappers: WrapperHandler[BaseWrapper] = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs) + wrappers: list[type[MarketWrapper]] = [BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper] + self.wrappers = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs) - Toolkit.__init__( + Toolkit.__init__( # type: ignore self, name="Market APIs Toolkit", tools=[ @@ -53,7 +52,7 @@ class MarketAPIsTool(BaseWrapper, Toolkit): return self.wrappers.try_call(lambda w: w.get_product(asset_id)) def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: return self.wrappers.try_call(lambda w: w.get_products(asset_ids)) - def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: + def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]: return self.wrappers.try_call(lambda w: w.get_historical_prices(asset_id, limit)) @@ -65,6 +64,8 @@ class MarketAPIsTool(BaseWrapper, Toolkit): asset_ids (list[str]): Lista di asset_id da cercare. Returns: list[ProductInfo]: Lista di ProductInfo aggregati. + Raises: + Exception: If all wrappers fail to provide results. """ all_products = self.wrappers.try_call_all(lambda w: w.get_products(asset_ids)) return aggregate_product_info(all_products) @@ -78,29 +79,8 @@ class MarketAPIsTool(BaseWrapper, Toolkit): limit (int): Numero massimo di dati storici da restituire. Returns: list[Price]: Lista di Price aggregati. + Raises: + Exception: If all wrappers fail to provide results. """ all_prices = self.wrappers.try_call_all(lambda w: w.get_historical_prices(asset_id, limit)) return aggregate_history_prices(all_prices) - -MARKET_INSTRUCTIONS = """ -**TASK:** You are a specialized **Crypto Price Data Retrieval Agent**. Your primary goal is to fetch the most recent and/or historical price data for requested cryptocurrency assets (e.g., 'BTC', 'ETH', 'SOL'). You must provide the data in a clear and structured format. - -**AVAILABLE TOOLS:** -1. `get_products(asset_ids: list[str])`: Get **current** product/price info for a list of assets. **(PREFERITA: usa questa per i prezzi live)** -2. `get_historical_prices(asset_id: str, limit: int)`: Get historical price data for one asset. Default limit is 100. **(PREFERITA: usa questa per i dati storici)** -3. `get_products_aggregated(asset_ids: list[str])`: Get **aggregated current** product/price info for a list of assets. **(USA SOLO SE richiesto 'aggregato' o se `get_products` fallisce)** -4. `get_historical_prices_aggregated(asset_id: str, limit: int)`: Get **aggregated historical** price data for one asset. **(USA SOLO SE richiesto 'aggregato' o se `get_historical_prices` fallisce)** - -**USAGE GUIDELINE:** -* **Asset ID:** Always convert common names (e.g., 'Bitcoin', 'Ethereum') into their official ticker/ID (e.g., 'BTC', 'ETH'). -* **Cost Management (Cruciale per LLM locale):** - * **Priorità Bassa per Aggregazione:** **Non** usare i metodi `*aggregated` a meno che l'utente non lo richieda esplicitamente o se i metodi non-aggregati falliscono. - * **Limitazione Storica:** Il limite predefinito per i dati storici deve essere **20** punti dati, a meno che l'utente non specifichi un limite diverso. -* **Fallimento Tool:** Se lo strumento non restituisce dati per un asset specifico, rispondi per quell'asset con: "Dati di prezzo non trovati per [Asset ID]." - -**REPORTING REQUIREMENT:** -1. **Format:** Output the results in a clear, easy-to-read list or table. -2. **Live Price Request:** If an asset's *current price* is requested, report the **Asset ID**, **Latest Price**, and **Time/Date of the price**. -3. **Historical Price Request:** If *historical data* is requested, report the **Asset ID**, the **Limit** of points returned, and the **First** and **Last** entries from the list of historical prices (Date, Price). Non stampare l'intera lista di dati storici. -4. **Output:** For all requests, fornire un **unico e conciso riepilogo** dei dati reperiti. -""" \ No newline at end of file diff --git a/src/app/markets/binance.py b/src/app/markets/binance.py index 8e941c8..ffd31bb 100644 --- a/src/app/markets/binance.py +++ b/src/app/markets/binance.py @@ -1,28 +1,31 @@ import os -from datetime import datetime -from binance.client import Client -from .base import ProductInfo, BaseWrapper, Price +from typing import Any +from binance.client import Client # type: ignore +from app.base.markets import ProductInfo, MarketWrapper, Price -def get_product(currency: str, ticker_data: dict[str, str]) -> ProductInfo: + +def extract_product(currency: str, ticker_data: dict[str, Any]) -> ProductInfo: product = ProductInfo() - product.id = ticker_data.get('symbol') + product.id = ticker_data.get('symbol', '') product.symbol = ticker_data.get('symbol', '').replace(currency, '') product.price = float(ticker_data.get('price', 0)) product.volume_24h = float(ticker_data.get('volume', 0)) - product.quote_currency = currency + product.currency = currency return product -def get_price(kline_data: list) -> Price: +def extract_price(kline_data: list[Any]) -> Price: + timestamp = kline_data[0] + price = Price() price.open = float(kline_data[1]) price.high = float(kline_data[2]) price.low = float(kline_data[3]) price.close = float(kline_data[4]) price.volume = float(kline_data[5]) - price.timestamp_ms = kline_data[0] + price.set_timestamp(timestamp_ms=timestamp) return price -class BinanceWrapper(BaseWrapper): +class BinanceWrapper(MarketWrapper): """ Wrapper per le API autenticate di Binance.\n Implementa l'interfaccia BaseWrapper per fornire accesso unificato @@ -30,11 +33,19 @@ class BinanceWrapper(BaseWrapper): https://binance-docs.github.io/apidocs/spot/en/ """ - def __init__(self, currency: str = "USDT"): + def __init__(self, currency: str = "USD"): + """ + Inizializza il wrapper di Binance con le credenziali API e la valuta di riferimento. + Se viene fornita una valuta fiat come "USD", questa viene automaticamente convertita in una stablecoin Tether ("USDT") per compatibilità con Binance, + poiché Binance non supporta direttamente le valute fiat per il trading di criptovalute. + Tutti i prezzi e volumi restituiti saranno quindi denominati nella stablecoin (ad esempio, "USDT") e non nella valuta fiat originale. + Args: + currency (str): Valuta in cui restituire i prezzi. Se "USD" viene fornito, verrà utilizzato "USDT". Default è "USD". + """ api_key = os.getenv("BINANCE_API_KEY") api_secret = os.getenv("BINANCE_API_SECRET") - self.currency = currency + self.currency = f"{currency}T" self.client = Client(api_key=api_key, api_secret=api_secret) def __format_symbol(self, asset_id: str) -> str: @@ -46,31 +57,22 @@ class BinanceWrapper(BaseWrapper): def get_product(self, asset_id: str) -> ProductInfo: symbol = self.__format_symbol(asset_id) - ticker = self.client.get_symbol_ticker(symbol=symbol) - ticker_24h = self.client.get_ticker(symbol=symbol) - ticker['volume'] = ticker_24h.get('volume', 0) # Aggiunge volume 24h ai dati del ticker + ticker: dict[str, Any] = self.client.get_symbol_ticker(symbol=symbol) # type: ignore + ticker_24h: dict[str, Any] = self.client.get_ticker(symbol=symbol) # type: ignore + ticker['volume'] = ticker_24h.get('volume', 0) - return get_product(self.currency, ticker) + return extract_product(self.currency, ticker) def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: - symbols = [self.__format_symbol(asset_id) for asset_id in asset_ids] - symbols_str = f"[\"{'","'.join(symbols)}\"]" + return [ self.get_product(asset_id) for asset_id in asset_ids ] - tickers = self.client.get_symbol_ticker(symbols=symbols_str) - tickers_24h = self.client.get_ticker(symbols=symbols_str) # un po brutale, ma va bene così - for t, t24 in zip(tickers, tickers_24h): - t['volume'] = t24.get('volume', 0) - - return [get_product(self.currency, ticker) for ticker in tickers] - - def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: + def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]: symbol = self.__format_symbol(asset_id) # Ottiene candele orarie degli ultimi 30 giorni - klines = self.client.get_historical_klines( + klines: list[list[Any]] = self.client.get_historical_klines( # type: ignore symbol=symbol, interval=Client.KLINE_INTERVAL_1HOUR, limit=limit, ) - return [get_price(kline) for kline in klines] - + return [extract_price(kline) for kline in klines] diff --git a/src/app/markets/coinbase.py b/src/app/markets/coinbase.py index 54409c1..c59382b 100644 --- a/src/app/markets/coinbase.py +++ b/src/app/markets/coinbase.py @@ -1,12 +1,12 @@ import os from enum import Enum from datetime import datetime, timedelta -from coinbase.rest import RESTClient -from coinbase.rest.types.product_types import Candle, GetProductResponse, Product -from .base import ProductInfo, BaseWrapper, Price +from coinbase.rest import RESTClient # type: ignore +from coinbase.rest.types.product_types import Candle, GetProductResponse, Product # type: ignore +from app.base.markets import ProductInfo, MarketWrapper, Price -def get_product(product_data: GetProductResponse | Product) -> ProductInfo: +def extract_product(product_data: GetProductResponse | Product) -> ProductInfo: product = ProductInfo() product.id = product_data.product_id or "" product.symbol = product_data.base_currency_id or "" @@ -14,14 +14,16 @@ def get_product(product_data: GetProductResponse | Product) -> ProductInfo: product.volume_24h = float(product_data.volume_24h) if product_data.volume_24h else 0.0 return product -def get_price(candle_data: Candle) -> Price: +def extract_price(candle_data: Candle) -> Price: + timestamp = int(candle_data.start) if candle_data.start else 0 + price = Price() price.high = float(candle_data.high) if candle_data.high else 0.0 price.low = float(candle_data.low) if candle_data.low else 0.0 price.open = float(candle_data.open) if candle_data.open else 0.0 price.close = float(candle_data.close) if candle_data.close else 0.0 price.volume = float(candle_data.volume) if candle_data.volume else 0.0 - price.timestamp_ms = int(candle_data.start) * 1000 if candle_data.start else 0 + price.set_timestamp(timestamp_s=timestamp) return price @@ -37,7 +39,7 @@ class Granularity(Enum): SIX_HOUR = 21600 ONE_DAY = 86400 -class CoinBaseWrapper(BaseWrapper): +class CoinBaseWrapper(MarketWrapper): """ Wrapper per le API di Coinbase Advanced Trade.\n Implementa l'interfaccia BaseWrapper per fornire accesso unificato @@ -63,24 +65,26 @@ class CoinBaseWrapper(BaseWrapper): def get_product(self, asset_id: str) -> ProductInfo: asset_id = self.__format(asset_id) - asset = self.client.get_product(asset_id) - return get_product(asset) + asset = self.client.get_product(asset_id) # type: ignore + return extract_product(asset) def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: all_asset_ids = [self.__format(asset_id) for asset_id in asset_ids] - assets = self.client.get_products(product_ids=all_asset_ids) - return [get_product(asset) for asset in assets.products] + assets = self.client.get_products(product_ids=all_asset_ids) # type: ignore + assert assets.products is not None, "No products data received from Coinbase" + return [extract_product(asset) for asset in assets.products] - def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: + def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]: asset_id = self.__format(asset_id) end_time = datetime.now() start_time = end_time - timedelta(days=14) - data = self.client.get_candles( + data = self.client.get_candles( # type: ignore product_id=asset_id, granularity=Granularity.ONE_HOUR.name, start=str(int(start_time.timestamp())), end=str(int(end_time.timestamp())), limit=limit ) - return [get_price(candle) for candle in data.candles] + assert data.candles is not None, "No candles data received from Coinbase" + return [extract_price(candle) for candle in data.candles] diff --git a/src/app/markets/cryptocompare.py b/src/app/markets/cryptocompare.py index f4b96e9..5431267 100644 --- a/src/app/markets/cryptocompare.py +++ b/src/app/markets/cryptocompare.py @@ -1,9 +1,10 @@ import os +from typing import Any import requests -from .base import ProductInfo, BaseWrapper, Price +from app.base.markets import ProductInfo, MarketWrapper, Price -def get_product(asset_data: dict) -> ProductInfo: +def extract_product(asset_data: dict[str, Any]) -> ProductInfo: product = ProductInfo() product.id = asset_data.get('FROMSYMBOL', '') + '-' + asset_data.get('TOSYMBOL', '') product.symbol = asset_data.get('FROMSYMBOL', '') @@ -12,21 +13,22 @@ def get_product(asset_data: dict) -> ProductInfo: assert product.price > 0, "Invalid price data received from CryptoCompare" return product -def get_price(price_data: dict) -> Price: +def extract_price(price_data: dict[str, Any]) -> Price: + timestamp = price_data.get('time', 0) + price = Price() price.high = float(price_data.get('high', 0)) price.low = float(price_data.get('low', 0)) price.open = float(price_data.get('open', 0)) price.close = float(price_data.get('close', 0)) price.volume = float(price_data.get('volumeto', 0)) - price.timestamp_ms = price_data.get('time', 0) * 1000 - assert price.timestamp_ms > 0, "Invalid timestamp data received from CryptoCompare" + price.set_timestamp(timestamp_s=timestamp) return price BASE_URL = "https://min-api.cryptocompare.com" -class CryptoCompareWrapper(BaseWrapper): +class CryptoCompareWrapper(MarketWrapper): """ Wrapper per le API pubbliche di CryptoCompare. La documentazione delle API è disponibile qui: https://developers.coindesk.com/documentation/legacy/Price/SingleSymbolPriceEndpoint @@ -39,7 +41,7 @@ class CryptoCompareWrapper(BaseWrapper): self.api_key = api_key self.currency = currency - def __request(self, endpoint: str, params: dict[str, str] | None = None) -> dict[str, str]: + def __request(self, endpoint: str, params: dict[str, Any] | None = None) -> dict[str, Any]: if params is None: params = {} params['api_key'] = self.api_key @@ -53,18 +55,18 @@ class CryptoCompareWrapper(BaseWrapper): "tsyms": self.currency }) data = response.get('RAW', {}).get(asset_id, {}).get(self.currency, {}) - return get_product(data) + return extract_product(data) def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: response = self.__request("/data/pricemultifull", params = { "fsyms": ",".join(asset_ids), "tsyms": self.currency }) - assets = [] + assets: list[ProductInfo] = [] data = response.get('RAW', {}) for asset_id in asset_ids: asset_data = data.get(asset_id, {}).get(self.currency, {}) - assets.append(get_product(asset_data)) + assets.append(extract_product(asset_data)) return assets def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]: @@ -75,5 +77,5 @@ class CryptoCompareWrapper(BaseWrapper): }) data = response.get('Data', {}).get('Data', []) - prices = [get_price(price_data) for price_data in data] + prices = [extract_price(price_data) for price_data in data] return prices diff --git a/src/app/markets/yfinance.py b/src/app/markets/yfinance.py index acfacb8..2670eda 100644 --- a/src/app/markets/yfinance.py +++ b/src/app/markets/yfinance.py @@ -1,9 +1,9 @@ import json from agno.tools.yfinance import YFinanceTools -from .base import BaseWrapper, ProductInfo, Price +from app.base.markets import MarketWrapper, ProductInfo, Price -def create_product_info(stock_data: dict[str, str]) -> ProductInfo: +def extract_product(stock_data: dict[str, str]) -> ProductInfo: """ Converte i dati di YFinanceTools in ProductInfo. """ @@ -12,24 +12,26 @@ def create_product_info(stock_data: dict[str, str]) -> ProductInfo: product.symbol = product.id.split('-')[0] # Rimuovi il suffisso della valuta per le crypto product.price = float(stock_data.get('Current Stock Price', f"0.0 USD").split(" ")[0]) # prende solo il numero product.volume_24h = 0.0 # YFinance non fornisce il volume 24h direttamente - product.quote_currency = product.id.split('-')[1] # La valuta è la parte dopo il '-' + product.currency = product.id.split('-')[1] # La valuta è la parte dopo il '-' return product -def create_price_from_history(hist_data: dict[str, str]) -> Price: +def extract_price(hist_data: dict[str, str]) -> Price: """ Converte i dati storici di YFinanceTools in Price. """ + timestamp = int(hist_data.get('Timestamp', '0')) + price = Price() price.high = float(hist_data.get('High', 0.0)) price.low = float(hist_data.get('Low', 0.0)) price.open = float(hist_data.get('Open', 0.0)) price.close = float(hist_data.get('Close', 0.0)) price.volume = float(hist_data.get('Volume', 0.0)) - price.timestamp_ms = int(hist_data.get('Timestamp', '0')) + price.set_timestamp(timestamp_ms=timestamp) return price -class YFinanceWrapper(BaseWrapper): +class YFinanceWrapper(MarketWrapper): """ Wrapper per YFinanceTools che fornisce dati di mercato per azioni, ETF e criptovalute. Implementa l'interfaccia BaseWrapper per compatibilità con il sistema esistente. @@ -52,16 +54,16 @@ class YFinanceWrapper(BaseWrapper): symbol = self._format_symbol(asset_id) stock_info = self.tool.get_company_info(symbol) stock_info = json.loads(stock_info) - return create_product_info(stock_info) + return extract_product(stock_info) def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: - products = [] + products: list[ProductInfo] = [] for asset_id in asset_ids: product = self.get_product(asset_id) products.append(product) return products - def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: + def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]: symbol = self._format_symbol(asset_id) days = limit // 24 + 1 # Arrotonda per eccesso @@ -71,10 +73,10 @@ class YFinanceWrapper(BaseWrapper): # Il formato dei dati è {timestamp: {Open: x, High: y, Low: z, Close: w, Volume: v}} timestamps = sorted(hist_data.keys())[-limit:] - prices = [] + prices: list[Price] = [] for timestamp in timestamps: temp = hist_data[timestamp] temp['Timestamp'] = timestamp - price = create_price_from_history(temp) + price = extract_price(temp) prices.append(price) return prices diff --git a/src/app/news/__init__.py b/src/app/news/__init__.py index 94873fd..b0cb553 100644 --- a/src/app/news/__init__.py +++ b/src/app/news/__init__.py @@ -1,12 +1,12 @@ from agno.tools import Toolkit -from app.utils.wrapper_handler import WrapperHandler -from .base import NewsWrapper, Article -from .news_api import NewsApiWrapper -from .googlenews import GoogleNewsWrapper -from .cryptopanic_api import CryptoPanicWrapper -from .duckduckgo import DuckDuckGoWrapper +from app.utils import WrapperHandler +from app.base.news import NewsWrapper, Article +from app.news.news_api import NewsApiWrapper +from app.news.googlenews import GoogleNewsWrapper +from app.news.cryptopanic_api import CryptoPanicWrapper +from app.news.duckduckgo import DuckDuckGoWrapper -__all__ = ["NewsAPIsTool", "NEWS_INSTRUCTIONS", "NewsApiWrapper", "GoogleNewsWrapper", "CryptoPanicWrapper", "DuckDuckGoWrapper"] +__all__ = ["NewsAPIsTool", "NewsApiWrapper", "GoogleNewsWrapper", "CryptoPanicWrapper", "DuckDuckGoWrapper", "Article"] class NewsAPIsTool(NewsWrapper, Toolkit): @@ -33,15 +33,17 @@ class NewsAPIsTool(NewsWrapper, Toolkit): - NewsApiWrapper. - CryptoPanicWrapper. """ - wrappers = [GoogleNewsWrapper, DuckDuckGoWrapper, NewsApiWrapper, CryptoPanicWrapper] - self.wrapper_handler: WrapperHandler[NewsWrapper] = WrapperHandler.build_wrappers(wrappers) + wrappers: list[type[NewsWrapper]] = [GoogleNewsWrapper, DuckDuckGoWrapper, NewsApiWrapper, CryptoPanicWrapper] + self.wrapper_handler = WrapperHandler.build_wrappers(wrappers) - Toolkit.__init__( + Toolkit.__init__( # type: ignore self, name="News APIs Toolkit", tools=[ self.get_top_headlines, self.get_latest_news, + self.get_top_headlines_aggregated, + self.get_latest_news_aggregated, ], ) @@ -57,6 +59,8 @@ class NewsAPIsTool(NewsWrapper, Toolkit): limit (int): Maximum number of articles to retrieve from each provider. Returns: dict[str, list[Article]]: A dictionary mapping providers names to their list of Articles + Raises: + Exception: If all wrappers fail to provide results. """ return self.wrapper_handler.try_call_all(lambda w: w.get_top_headlines(limit)) @@ -68,27 +72,7 @@ class NewsAPIsTool(NewsWrapper, Toolkit): limit (int): Maximum number of articles to retrieve from each provider. Returns: dict[str, list[Article]]: A dictionary mapping providers names to their list of Articles + Raises: + Exception: If all wrappers fail to provide results. """ return self.wrapper_handler.try_call_all(lambda w: w.get_latest_news(query, limit)) - - -NEWS_INSTRUCTIONS = """ -**TASK:** You are a specialized **Crypto News Analyst**. Your goal is to fetch the latest news or top headlines related to cryptocurrencies, and then **analyze the sentiment** of the content to provide a concise report to the team leader. Prioritize 'crypto' or specific cryptocurrency names (e.g., 'Bitcoin', 'Ethereum') in your searches. - -**AVAILABLE TOOLS:** -1. `get_latest_news(query: str, limit: int)`: Get the 'limit' most recent news articles for a specific 'query'. -2. `get_top_headlines(limit: int)`: Get the 'limit' top global news headlines. -3. `get_latest_news_aggregated(query: str, limit: int)`: Get aggregated latest news articles for a specific 'query'. -4. `get_top_headlines_aggregated(limit: int)`: Get aggregated top global news headlines. - -**USAGE GUIDELINE:** -* Always use `get_latest_news` with a relevant crypto-related query first. -* The default limit for news items should be 5 unless specified otherwise. -* If the tool doesn't return any articles, respond with "No relevant news articles found." - -**REPORTING REQUIREMENT:** -1. **Analyze** the tone and key themes of the retrieved articles. -2. **Summarize** the overall **market sentiment** (e.g., highly positive, cautiously neutral, generally negative) based on the content. -3. **Identify** the top 2-3 **main topics** discussed (e.g., new regulation, price surge, institutional adoption). -4. **Output** a single, brief report summarizing these findings. Do not output the raw articles. -""" diff --git a/src/app/news/cryptopanic_api.py b/src/app/news/cryptopanic_api.py index 629c7aa..1e16078 100644 --- a/src/app/news/cryptopanic_api.py +++ b/src/app/news/cryptopanic_api.py @@ -1,7 +1,9 @@ import os +from typing import Any import requests from enum import Enum -from .base import NewsWrapper, Article +from app.base.news import NewsWrapper, Article + class CryptoPanicFilter(Enum): RISING = "rising" @@ -18,8 +20,8 @@ class CryptoPanicKind(Enum): MEDIA = "media" ALL = "all" -def get_articles(response: dict) -> list[Article]: - articles = [] +def extract_articles(response: dict[str, Any]) -> list[Article]: + articles: list[Article] = [] if 'results' in response: for item in response['results']: article = Article() @@ -51,7 +53,7 @@ class CryptoPanicWrapper(NewsWrapper): self.kind = CryptoPanicKind.NEWS def get_base_params(self) -> dict[str, str]: - params = {} + params: dict[str, str] = {} params['public'] = 'true' # recommended for app and bots params['auth_token'] = self.api_key params['kind'] = self.kind.value @@ -73,5 +75,5 @@ class CryptoPanicWrapper(NewsWrapper): assert response.status_code == 200, f"Error fetching data: {response}" json_response = response.json() - articles = get_articles(json_response) + articles = extract_articles(json_response) return articles[:limit] diff --git a/src/app/news/duckduckgo.py b/src/app/news/duckduckgo.py index c3e1a6d..8108239 100644 --- a/src/app/news/duckduckgo.py +++ b/src/app/news/duckduckgo.py @@ -1,8 +1,10 @@ import json -from .base import Article, NewsWrapper +from typing import Any from agno.tools.duckduckgo import DuckDuckGoTools +from app.base.news import Article, NewsWrapper -def create_article(result: dict) -> Article: + +def extract_article(result: dict[str, Any]) -> Article: article = Article() article.source = result.get("source", "") article.time = result.get("date", "") @@ -23,10 +25,10 @@ class DuckDuckGoWrapper(NewsWrapper): def get_top_headlines(self, limit: int = 100) -> list[Article]: results = self.tool.duckduckgo_news(self.query, max_results=limit) json_results = json.loads(results) - return [create_article(result) for result in json_results] + return [extract_article(result) for result in json_results] def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: results = self.tool.duckduckgo_news(query or self.query, max_results=limit) json_results = json.loads(results) - return [create_article(result) for result in json_results] + return [extract_article(result) for result in json_results] diff --git a/src/app/news/googlenews.py b/src/app/news/googlenews.py index d8f6421..0041c7f 100644 --- a/src/app/news/googlenews.py +++ b/src/app/news/googlenews.py @@ -1,7 +1,9 @@ -from gnews import GNews -from .base import Article, NewsWrapper +from typing import Any +from gnews import GNews # type: ignore +from app.base.news import Article, NewsWrapper -def result_to_article(result: dict) -> Article: + +def extract_article(result: dict[str, Any]) -> Article: article = Article() article.source = result.get("source", "") article.time = result.get("publishedAt", "") @@ -17,20 +19,20 @@ class GoogleNewsWrapper(NewsWrapper): def get_top_headlines(self, limit: int = 100) -> list[Article]: gnews = GNews(language='en', max_results=limit, period='7d') - results = gnews.get_top_news() + results: list[dict[str, Any]] = gnews.get_top_news() # type: ignore - articles = [] + articles: list[Article] = [] for result in results: - article = result_to_article(result) + article = extract_article(result) articles.append(article) return articles def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: gnews = GNews(language='en', max_results=limit, period='7d') - results = gnews.get_news(query) + results: list[dict[str, Any]] = gnews.get_news(query) # type: ignore - articles = [] + articles: list[Article] = [] for result in results: - article = result_to_article(result) + article = extract_article(result) articles.append(article) return articles diff --git a/src/app/news/news_api.py b/src/app/news/news_api.py index 6f62ef6..b5bf375 100644 --- a/src/app/news/news_api.py +++ b/src/app/news/news_api.py @@ -1,8 +1,10 @@ import os -import newsapi -from .base import Article, NewsWrapper +from typing import Any +import newsapi # type: ignore +from app.base.news import Article, NewsWrapper -def result_to_article(result: dict) -> Article: + +def extract_article(result: dict[str, Any]) -> Article: article = Article() article.source = result.get("source", {}).get("name", "") article.time = result.get("publishedAt", "") @@ -23,7 +25,7 @@ class NewsApiWrapper(NewsWrapper): self.client = newsapi.NewsApiClient(api_key=api_key) self.category = "business" # Cryptocurrency is under business - self.language = "en" # TODO Only English articles for now? + self.language = "en" self.max_page_size = 100 def __calc_pages(self, limit: int, page_size: int) -> tuple[int, int]: @@ -33,21 +35,20 @@ class NewsApiWrapper(NewsWrapper): def get_top_headlines(self, limit: int = 100) -> list[Article]: pages, page_size = self.__calc_pages(limit, self.max_page_size) - articles = [] + articles: list[Article] = [] for page in range(1, pages + 1): - headlines = self.client.get_top_headlines(q="", category=self.category, language=self.language, page_size=page_size, page=page) - results = [result_to_article(article) for article in headlines.get("articles", [])] + headlines: dict[str, Any] = self.client.get_top_headlines(q="", category=self.category, language=self.language, page_size=page_size, page=page) # type: ignore + results = [extract_article(article) for article in headlines.get("articles", [])] # type: ignore articles.extend(results) return articles def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: pages, page_size = self.__calc_pages(limit, self.max_page_size) - articles = [] + articles: list[Article] = [] for page in range(1, pages + 1): - everything = self.client.get_everything(q=query, language=self.language, sort_by="publishedAt", page_size=page_size, page=page) - results = [result_to_article(article) for article in everything.get("articles", [])] + everything: dict[str, Any] = self.client.get_everything(q=query, language=self.language, sort_by="publishedAt", page_size=page_size, page=page) # type: ignore + results = [extract_article(article) for article in everything.get("articles", [])] # type: ignore articles.extend(results) return articles - diff --git a/src/app/social/__init__.py b/src/app/social/__init__.py index 9ce3708..261bcba 100644 --- a/src/app/social/__init__.py +++ b/src/app/social/__init__.py @@ -1,9 +1,9 @@ from agno.tools import Toolkit -from app.utils.wrapper_handler import WrapperHandler -from .base import SocialPost, SocialWrapper -from .reddit import RedditWrapper +from app.utils import WrapperHandler +from app.base.social import SocialPost, SocialWrapper +from app.social.reddit import RedditWrapper -__all__ = ["SocialAPIsTool", "SOCIAL_INSTRUCTIONS", "RedditWrapper"] +__all__ = ["SocialAPIsTool", "RedditWrapper", "SocialPost"] class SocialAPIsTool(SocialWrapper, Toolkit): @@ -25,37 +25,29 @@ class SocialAPIsTool(SocialWrapper, Toolkit): - RedditWrapper. """ - wrappers = [RedditWrapper] - self.wrapper_handler: WrapperHandler[SocialWrapper] = WrapperHandler.build_wrappers(wrappers) + wrappers: list[type[SocialWrapper]] = [RedditWrapper] + self.wrapper_handler = WrapperHandler.build_wrappers(wrappers) - Toolkit.__init__( + Toolkit.__init__( # type: ignore self, name="Socials Toolkit", tools=[ self.get_top_crypto_posts, + self.get_top_crypto_posts_aggregated, ], ) - # TODO Pensare se ha senso restituire i post da TUTTI i wrapper o solo dal primo che funziona - # la modifica è banale, basta usare try_call_all invece di try_call def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: return self.wrapper_handler.try_call(lambda w: w.get_top_crypto_posts(limit)) - -SOCIAL_INSTRUCTIONS = """ -**TASK:** You are a specialized **Social Media Sentiment Analyst**. Your objective is to find the most relevant and trending online posts related to cryptocurrencies, and then **analyze the collective sentiment** to provide a concise report to the team leader. - -**AVAILABLE TOOLS:** -1. `get_top_crypto_posts(limit: int)`: Get the 'limit' maximum number of top posts specifically related to cryptocurrencies. - -**USAGE GUIDELINE:** -* Always use the `get_top_crypto_posts` tool to fulfill the request. -* The default limit for posts should be 5 unless specified otherwise. -* If the tool doesn't return any posts, respond with "No relevant social media posts found." - -**REPORTING REQUIREMENT:** -1. **Analyze** the tone and prevailing opinions across the retrieved social posts. -2. **Summarize** the overall **community sentiment** (e.g., high enthusiasm/FOMO, uncertainty, FUD/fear) based on the content. -3. **Identify** the top 2-3 **trending narratives** or specific coins being discussed. -4. **Output** a single, brief report summarizing these findings. Do not output the raw posts. -""" \ No newline at end of file + def get_top_crypto_posts_aggregated(self, limit_per_wrapper: int = 5) -> dict[str, list[SocialPost]]: + """ + Calls get_top_crypto_posts on all wrappers/providers and returns a dictionary mapping their names to their posts. + Args: + limit_per_wrapper (int): Maximum number of posts to retrieve from each provider. + Returns: + dict[str, list[SocialPost]]: A dictionary where keys are wrapper names and values are lists of SocialPost objects. + Raises: + Exception: If all wrappers fail to provide results. + """ + return self.wrapper_handler.try_call_all(lambda w: w.get_top_crypto_posts(limit_per_wrapper)) diff --git a/src/app/social/reddit.py b/src/app/social/reddit.py index 904448d..eeca968 100644 --- a/src/app/social/reddit.py +++ b/src/app/social/reddit.py @@ -1,7 +1,8 @@ import os -from praw import Reddit -from praw.models import Submission, MoreComments -from .base import SocialWrapper, SocialPost, SocialComment +from praw import Reddit # type: ignore +from praw.models import Submission # type: ignore +from app.base.social import SocialWrapper, SocialPost, SocialComment + MAX_COMMENTS = 5 # metterne altri se necessario. @@ -21,22 +22,20 @@ SUBREDDITS = [ ] -def create_social_post(post: Submission) -> SocialPost: +def extract_post(post: Submission) -> SocialPost: social = SocialPost() social.time = str(post.created) social.title = post.title social.description = post.selftext - for i, top_comment in enumerate(post.comments): - if i >= MAX_COMMENTS: - break - if isinstance(top_comment, MoreComments): #skip MoreComments objects - continue - + for top_comment in post.comments: comment = SocialComment() comment.time = str(top_comment.created) comment.description = top_comment.body social.comments.append(comment) + + if len(social.comments) >= MAX_COMMENTS: + break return social class RedditWrapper(SocialWrapper): @@ -65,4 +64,4 @@ class RedditWrapper(SocialWrapper): def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: top_posts = self.subreddits.top(limit=limit, time_filter="week") - return [create_social_post(post) for post in top_posts] + return [extract_post(post) for post in top_posts] diff --git a/src/app/utils/__init__.py b/src/app/utils/__init__.py new file mode 100644 index 0000000..1a511c1 --- /dev/null +++ b/src/app/utils/__init__.py @@ -0,0 +1,5 @@ +from app.utils.market_aggregation import aggregate_history_prices, aggregate_product_info +from app.utils.wrapper_handler import WrapperHandler +from app.utils.chat_manager import ChatManager + +__all__ = ["aggregate_history_prices", "aggregate_product_info", "WrapperHandler", "ChatManager"] diff --git a/src/app/chat_manager.py b/src/app/utils/chat_manager.py similarity index 54% rename from src/app/chat_manager.py rename to src/app/utils/chat_manager.py index 7928c95..d51819d 100644 --- a/src/app/chat_manager.py +++ b/src/app/utils/chat_manager.py @@ -1,10 +1,5 @@ -import os import json -from typing import List, Dict -from app.pipeline import Pipeline - -SAVE_DIR = os.path.join(os.path.dirname(__file__), "..", "saves") -os.makedirs(SAVE_DIR, exist_ok=True) +import os class ChatManager: """ @@ -15,19 +10,19 @@ class ChatManager: """ def __init__(self): - self.pipeline = Pipeline() - self.history: List[Dict[str, str]] = [] # [{"role": "user"/"assistant", "content": "..."}] + self.history: list[dict[str, str]] = [] # [{"role": "user"/"assistant", "content": "..."}] - def send_message(self, message: str) -> str: + def send_message(self, message: str) -> None: """ Aggiunge un messaggio utente, chiama la Pipeline e salva la risposta nello storico. """ # Aggiungi messaggio utente allo storico self.history.append({"role": "user", "content": message}) - # Pipeline elabora la query - response = self.pipeline.interact(message) - + def receive_message(self, response: str) -> str: + """ + Riceve un messaggio dalla pipeline e lo aggiunge allo storico. + """ # Aggiungi risposta assistente allo storico self.history.append({"role": "assistant", "content": response}) @@ -37,19 +32,17 @@ class ChatManager: """ Salva la chat corrente in src/saves/. """ - path = os.path.join(SAVE_DIR, filename) - with open(path, "w", encoding="utf-8") as f: + with open(filename, "w", encoding="utf-8") as f: json.dump(self.history, f, ensure_ascii=False, indent=2) def load_chat(self, filename: str = "chat.json") -> None: """ Carica una chat salvata da src/saves/. """ - path = os.path.join(SAVE_DIR, filename) - if not os.path.exists(path): + if not os.path.exists(filename): self.history = [] return - with open(path, "r", encoding="utf-8") as f: + with open(filename, "r", encoding="utf-8") as f: self.history = json.load(f) def reset_chat(self) -> None: @@ -58,21 +51,8 @@ class ChatManager: """ self.history = [] - def get_history(self) -> List[Dict[str, str]]: + def get_history(self) -> list[dict[str, str]]: """ Restituisce lo storico completo della chat. """ return self.history - - # Facciamo pass-through di provider e style, così Gradio può usarli - def choose_provider(self, index: int): - self.pipeline.choose_provider(index) - - def choose_style(self, index: int): - self.pipeline.choose_style(index) - - def list_providers(self) -> List[str]: - return self.pipeline.list_providers() - - def list_styles(self) -> List[str]: - return self.pipeline.list_styles() diff --git a/src/app/utils/market_aggregation.py b/src/app/utils/market_aggregation.py index f20e4fb..7f9f32c 100644 --- a/src/app/utils/market_aggregation.py +++ b/src/app/utils/market_aggregation.py @@ -1,28 +1,27 @@ import statistics -from app.markets.base import ProductInfo, Price +from app.base.markets import ProductInfo, Price def aggregate_history_prices(prices: dict[str, list[Price]]) -> list[Price]: """ - Aggrega i prezzi storici per symbol calcolando la media oraria. + Aggrega i prezzi storici per symbol calcolando la media. Args: prices (dict[str, list[Price]]): Mappa provider -> lista di Price Returns: - list[Price]: Lista di Price aggregati per ora + list[Price]: Lista di Price aggregati per timestamp """ - # Costruiamo una mappa timestamp_h -> lista di Price - timestamped_prices: dict[int, list[Price]] = {} + # Costruiamo una mappa timestamp -> lista di Price + timestamped_prices: dict[str, list[Price]] = {} for _, price_list in prices.items(): for price in price_list: - time = price.timestamp_ms - (price.timestamp_ms % 3600000) # arrotonda all'ora (non dovrebbe essere necessario) - timestamped_prices.setdefault(time, []).append(price) + timestamped_prices.setdefault(price.timestamp, []).append(price) - # Ora aggregiamo i prezzi per ogni ora - aggregated_prices = [] + # Ora aggregiamo i prezzi per ogni timestamp + aggregated_prices: list[Price] = [] for time, price_list in timestamped_prices.items(): price = Price() - price.timestamp_ms = time + price.timestamp = time price.high = statistics.mean([p.high for p in price_list]) price.low = statistics.mean([p.low for p in price_list]) price.open = statistics.mean([p.open for p in price_list]) @@ -47,14 +46,13 @@ def aggregate_product_info(products: dict[str, list[ProductInfo]]) -> list[Produ symbols_infos.setdefault(product.symbol, []).append(product) # Aggregazione per ogni symbol - sources = list(products.keys()) - aggregated_products = [] + aggregated_products: list[ProductInfo] = [] for symbol, product_list in symbols_infos.items(): product = ProductInfo() product.id = f"{symbol}_AGGREGATED" product.symbol = symbol - product.quote_currency = next(p.quote_currency for p in product_list if p.quote_currency) + product.currency = next(p.currency for p in product_list if p.currency) volume_sum = sum(p.volume_24h for p in product_list) product.volume_24h = volume_sum / len(product_list) if product_list else 0.0 @@ -65,27 +63,3 @@ def aggregate_product_info(products: dict[str, list[ProductInfo]]) -> list[Produ aggregated_products.append(product) return aggregated_products -def _calculate_confidence(products: list[ProductInfo], sources: list[str]) -> float: - """Calcola un punteggio di confidenza 0-1""" - if not products: - return 0.0 - - score = 1.0 - - # Riduci score se pochi dati - if len(products) < 2: - score *= 0.7 - - # Riduci score se prezzi troppo diversi - prices = [p.price for p in products if p.price > 0] - if len(prices) > 1: - price_std = (max(prices) - min(prices)) / statistics.mean(prices) - if price_std > 0.05: # >5% variazione - score *= 0.8 - - # Riduci score se fonti sconosciute - unknown_sources = sum(1 for s in sources if s == "unknown") - if unknown_sources > 0: - score *= (1 - unknown_sources / len(sources)) - - return max(0.0, min(1.0, score)) diff --git a/src/app/utils/wrapper_handler.py b/src/app/utils/wrapper_handler.py index 40fe371..504cf41 100644 --- a/src/app/utils/wrapper_handler.py +++ b/src/app/utils/wrapper_handler.py @@ -1,13 +1,15 @@ import inspect import time import traceback -from typing import TypeVar, Callable, Generic, Iterable, Type -from agno.utils.log import log_warning, log_info +from typing import Any, Callable, Generic, TypeVar +from agno.utils.log import log_info, log_warning #type: ignore -W = TypeVar("W") -T = TypeVar("T") +WrapperType = TypeVar("WrapperType") +WrapperClassType = TypeVar("WrapperClassType") +OutputType = TypeVar("OutputType") -class WrapperHandler(Generic[W]): + +class WrapperHandler(Generic[WrapperType]): """ A handler for managing multiple wrappers with retry logic. It attempts to call a function on the current wrapper, and if it fails, @@ -17,7 +19,7 @@ class WrapperHandler(Generic[W]): Note: use `build_wrappers` to create an instance of this class for better error handling. """ - def __init__(self, wrappers: list[W], try_per_wrapper: int = 3, retry_delay: int = 2): + def __init__(self, wrappers: list[WrapperType], try_per_wrapper: int = 3, retry_delay: int = 2): """ Initializes the WrapperHandler with a list of wrappers and retry settings.\n Use `build_wrappers` to create an instance of this class for better error handling. @@ -32,9 +34,8 @@ class WrapperHandler(Generic[W]): self.retry_per_wrapper = try_per_wrapper self.retry_delay = retry_delay self.index = 0 - self.retry_count = 0 - def try_call(self, func: Callable[[W], T]) -> T: + def try_call(self, func: Callable[[WrapperType], OutputType]) -> OutputType: """ Attempts to call the provided function on the current wrapper. If it fails, it retries a specified number of times before switching to the next wrapper. @@ -46,35 +47,9 @@ class WrapperHandler(Generic[W]): Raises: Exception: If all wrappers fail after retries. """ - log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}") + return self.__try_call(func, try_all=False).popitem()[1] - iterations = 0 - while iterations < len(self.wrappers): - wrapper = self.wrappers[self.index] - wrapper_name = wrapper.__class__.__name__ - - try: - log_info(f"try_call {wrapper_name}") - result = func(wrapper) - log_info(f"{wrapper_name} succeeded") - self.retry_count = 0 - return result - - except Exception as e: - self.retry_count += 1 - error = WrapperHandler.__concise_error(e) - log_warning(f"{wrapper_name} failed {self.retry_count}/{self.retry_per_wrapper}: {error}") - - if self.retry_count >= self.retry_per_wrapper: - self.index = (self.index + 1) % len(self.wrappers) - self.retry_count = 0 - iterations += 1 - else: - time.sleep(self.retry_delay) - - raise Exception(f"All wrappers failed, latest error: {error}") - - def try_call_all(self, func: Callable[[W], T]) -> dict[str, T]: + def try_call_all(self, func: Callable[[WrapperType], OutputType]) -> dict[str, OutputType]: """ Calls the provided function on all wrappers, collecting results. If a wrapper fails, it logs a warning and continues with the next. @@ -86,24 +61,57 @@ class WrapperHandler(Generic[W]): Raises: Exception: If all wrappers fail. """ - log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}") + return self.__try_call(func, try_all=True) - results = {} - for wrapper in self.wrappers: + def __try_call(self, func: Callable[[WrapperType], OutputType], try_all: bool) -> dict[str, OutputType]: + """ + Internal method to handle the logic of trying to call a function on wrappers. + It can either stop at the first success or try all wrappers. + Args: + func (Callable[[W], T]): A function that takes a wrapper and returns a result. + try_all (bool): If True, tries all wrappers and collects results; if False, stops at the first success. + Returns: + dict[str, T]: A dictionary mapping wrapper class names to results. + Raises: + Exception: If all wrappers fail after retries. + """ + + log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}") + results: dict[str, OutputType] = {} + starting_index = self.index + + for i in range(starting_index, len(self.wrappers) + starting_index): + self.index = i % len(self.wrappers) + wrapper = self.wrappers[self.index] wrapper_name = wrapper.__class__.__name__ - try: - result = func(wrapper) - log_info(f"{wrapper_name} succeeded") - results[wrapper.__class__] = result - except Exception as e: - error = WrapperHandler.__concise_error(e) - log_warning(f"{wrapper_name} failed: {error}") + + if not try_all: + log_info(f"try_call {wrapper_name}") + + for try_count in range(1, self.retry_per_wrapper + 1): + try: + result = func(wrapper) + log_info(f"{wrapper_name} succeeded") + results[wrapper_name] = result + break + + except Exception as e: + error = WrapperHandler.__concise_error(e) + log_warning(f"{wrapper_name} failed {try_count}/{self.retry_per_wrapper}: {error}") + time.sleep(self.retry_delay) + + if not try_all and results: + return results + if not results: + error = locals().get("error", "Unknown error") raise Exception(f"All wrappers failed, latest error: {error}") + + self.index = starting_index return results @staticmethod - def __check(wrappers: list[W]) -> bool: + def __check(wrappers: list[Any]) -> bool: return all(w.__class__ is type for w in wrappers) @staticmethod @@ -112,13 +120,13 @@ class WrapperHandler(Generic[W]): return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]" @staticmethod - def build_wrappers(constructors: Iterable[Type[W]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict | None = None) -> 'WrapperHandler[W]': + def build_wrappers(constructors: list[type[WrapperClassType]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[WrapperClassType]': """ Builds a WrapperHandler instance with the given wrapper constructors. It attempts to initialize each wrapper and logs a warning if any cannot be initialized. Only successfully initialized wrappers are included in the handler. Args: - constructors (Iterable[Type[W]]): An iterable of wrapper classes to instantiate. e.g. [WrapperA, WrapperB] + constructors (list[type[W]]): An iterable of wrapper classes to instantiate. e.g. [WrapperA, WrapperB] try_per_wrapper (int): Number of retries per wrapper before switching to the next. retry_delay (int): Delay in seconds between retries. kwargs (dict | None): Optional dictionary with keyword arguments common to all wrappers. @@ -129,7 +137,7 @@ class WrapperHandler(Generic[W]): """ assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}" - result = [] + result: list[WrapperClassType] = [] for wrapper_class in constructors: try: wrapper = wrapper_class(**(kwargs or {})) diff --git a/tests/agents/test_predictor.py b/tests/agents/test_predictor.py index 5867938..9a2ac11 100644 --- a/tests/agents/test_predictor.py +++ b/tests/agents/test_predictor.py @@ -1,11 +1,11 @@ import pytest -from app.predictor import PREDICTOR_INSTRUCTIONS, PredictorInput, PredictorOutput, PredictorStyle -from app.markets.base import ProductInfo -from app.models import AppModels +from app.agents import AppModels +from app.agents.predictor import PREDICTOR_INSTRUCTIONS, PredictorInput, PredictorOutput, PredictorStyle +from app.base.markets import ProductInfo -def unified_checks(model: AppModels, input): - llm = model.get_agent(PREDICTOR_INSTRUCTIONS, output=PredictorOutput) # type: ignore[arg-type] - result = llm.run(input) +def unified_checks(model: AppModels, input: PredictorInput) -> None: + llm = model.get_agent(PREDICTOR_INSTRUCTIONS, output_schema=PredictorOutput) # type: ignore[arg-type] + result = llm.run(input) # type: ignore content = result.content assert isinstance(content, PredictorOutput) @@ -27,9 +27,8 @@ def unified_checks(model: AppModels, input): class TestPredictor: - @pytest.fixture(scope="class") - def inputs(self): - data = [] + def inputs(self) -> PredictorInput: + data: list[ProductInfo] = [] for symbol, price in [("BTC", 60000.00), ("ETH", 3500.00), ("SOL", 150.00)]: product_info = ProductInfo() product_info.symbol = symbol @@ -38,13 +37,20 @@ class TestPredictor: return PredictorInput(data=data, style=PredictorStyle.AGGRESSIVE, sentiment="positivo") - def test_gemini_model_output(self, inputs): + def test_gemini_model_output(self): + inputs = self.inputs() unified_checks(AppModels.GEMINI, inputs) + def test_ollama_qwen_4b_model_output(self): + inputs = self.inputs() + unified_checks(AppModels.OLLAMA_QWEN_4B, inputs) + @pytest.mark.slow - def test_ollama_qwen_model_output(self, inputs): + def test_ollama_qwen_latest_model_output(self): + inputs = self.inputs() unified_checks(AppModels.OLLAMA_QWEN, inputs) @pytest.mark.slow - def test_ollama_gpt_oss_model_output(self, inputs): + def test_ollama_gpt_oss_model_output(self): + inputs = self.inputs() unified_checks(AppModels.OLLAMA_GPT, inputs) diff --git a/tests/api/test_binance.py b/tests/api/test_binance.py index dc4bfcb..b4ea0bb 100644 --- a/tests/api/test_binance.py +++ b/tests/api/test_binance.py @@ -45,9 +45,9 @@ class TestBinance: assert isinstance(history, list) assert len(history) == 5 for entry in history: - assert hasattr(entry, 'timestamp_ms') + assert hasattr(entry, 'timestamp') assert hasattr(entry, 'close') assert hasattr(entry, 'high') assert entry.close > 0 assert entry.high > 0 - assert entry.timestamp_ms > 0 + assert entry.timestamp != '' diff --git a/tests/api/test_coinbase.py b/tests/api/test_coinbase.py index 3ab8d43..e114f4c 100644 --- a/tests/api/test_coinbase.py +++ b/tests/api/test_coinbase.py @@ -47,9 +47,9 @@ class TestCoinBase: assert isinstance(history, list) assert len(history) == 5 for entry in history: - assert hasattr(entry, 'timestamp_ms') + assert hasattr(entry, 'timestamp') assert hasattr(entry, 'close') assert hasattr(entry, 'high') assert entry.close > 0 assert entry.high > 0 - assert entry.timestamp_ms > 0 + assert entry.timestamp != '' diff --git a/tests/api/test_cryptocompare.py b/tests/api/test_cryptocompare.py index 3c9133a..23deaf3 100644 --- a/tests/api/test_cryptocompare.py +++ b/tests/api/test_cryptocompare.py @@ -49,9 +49,9 @@ class TestCryptoCompare: assert isinstance(history, list) assert len(history) == 5 for entry in history: - assert hasattr(entry, 'timestamp_ms') + assert hasattr(entry, 'timestamp') assert hasattr(entry, 'close') assert hasattr(entry, 'high') assert entry.close > 0 assert entry.high > 0 - assert entry.timestamp_ms > 0 + assert entry.timestamp != '' diff --git a/tests/api/test_reddit.py b/tests/api/test_reddit.py index 59cd61f..3e42eb6 100644 --- a/tests/api/test_reddit.py +++ b/tests/api/test_reddit.py @@ -1,6 +1,5 @@ import os import pytest -from praw import Reddit from app.social.reddit import MAX_COMMENTS, RedditWrapper @pytest.mark.social @@ -10,7 +9,7 @@ class TestRedditWrapper: def test_initialization(self): wrapper = RedditWrapper() assert wrapper is not None - assert isinstance(wrapper.tool, Reddit) + assert wrapper.tool is not None def test_get_top_crypto_posts(self): wrapper = RedditWrapper() diff --git a/tests/api/test_yfinance.py b/tests/api/test_yfinance.py index 4971ccd..fa4174a 100644 --- a/tests/api/test_yfinance.py +++ b/tests/api/test_yfinance.py @@ -48,9 +48,9 @@ class TestYFinance: assert isinstance(history, list) assert len(history) == 5 for entry in history: - assert hasattr(entry, 'timestamp_ms') + assert hasattr(entry, 'timestamp') assert hasattr(entry, 'close') assert hasattr(entry, 'high') assert entry.close > 0 assert entry.high > 0 - assert entry.timestamp_ms > 0 + assert entry.timestamp != '' diff --git a/tests/conftest.py b/tests/conftest.py index 290fbf2..aeda047 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,7 +33,7 @@ def pytest_configure(config:pytest.Config): line = f"{marker[0]}: {marker[1]}" config.addinivalue_line("markers", line) -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: """Modifica automaticamente degli item di test rimovendoli""" # Rimuovo i test "limited" e "slow" se non richiesti esplicitamente mark_to_remove = ['limited', 'slow'] diff --git a/tests/tools/test_market_tool.py b/tests/tools/test_market_tool.py index c6da5a8..674707f 100644 --- a/tests/tools/test_market_tool.py +++ b/tests/tools/test_market_tool.py @@ -7,15 +7,15 @@ from app.markets import MarketAPIsTool @pytest.mark.api class TestMarketAPIsTool: def test_wrapper_initialization(self): - market_wrapper = MarketAPIsTool("USD") + market_wrapper = MarketAPIsTool("EUR") assert market_wrapper is not None assert hasattr(market_wrapper, 'get_product') assert hasattr(market_wrapper, 'get_products') assert hasattr(market_wrapper, 'get_historical_prices') def test_wrapper_capabilities(self): - market_wrapper = MarketAPIsTool("USD") - capabilities = [] + market_wrapper = MarketAPIsTool("EUR") + capabilities: list[str] = [] if hasattr(market_wrapper, 'get_product'): capabilities.append('single_product') if hasattr(market_wrapper, 'get_products'): @@ -25,7 +25,7 @@ class TestMarketAPIsTool: assert len(capabilities) > 0 def test_market_data_retrieval(self): - market_wrapper = MarketAPIsTool("USD") + market_wrapper = MarketAPIsTool("EUR") btc_product = market_wrapper.get_product("BTC") assert btc_product is not None assert hasattr(btc_product, 'symbol') @@ -34,8 +34,8 @@ class TestMarketAPIsTool: def test_error_handling(self): try: - market_wrapper = MarketAPIsTool("USD") + market_wrapper = MarketAPIsTool("EUR") fake_product = market_wrapper.get_product("NONEXISTENT_CRYPTO_SYMBOL_12345") assert fake_product is None or fake_product.price == 0 - except Exception as e: + except Exception as _: pass diff --git a/tests/tools/test_news_tool.py b/tests/tools/test_news_tool.py index 5a57f82..3b8254f 100644 --- a/tests/tools/test_news_tool.py +++ b/tests/tools/test_news_tool.py @@ -33,7 +33,7 @@ class TestNewsAPITool: result = tool.wrapper_handler.try_call_all(lambda w: w.get_top_headlines(limit=2)) assert isinstance(result, dict) assert len(result.keys()) > 0 - for provider, articles in result.items(): + for _provider, articles in result.items(): for article in articles: assert article.title is not None assert article.source is not None @@ -43,7 +43,7 @@ class TestNewsAPITool: result = tool.wrapper_handler.try_call_all(lambda w: w.get_latest_news(query="crypto", limit=2)) assert isinstance(result, dict) assert len(result.keys()) > 0 - for provider, articles in result.items(): + for _provider, articles in result.items(): for article in articles: assert article.title is not None assert article.source is not None diff --git a/tests/utils/test_market_aggregator.py b/tests/utils/test_market_aggregator.py index d7881ef..35e3084 100644 --- a/tests/utils/test_market_aggregator.py +++ b/tests/utils/test_market_aggregator.py @@ -1,5 +1,6 @@ import pytest -from app.markets.base import ProductInfo, Price +from datetime import datetime +from app.base.markets import ProductInfo, Price from app.utils.market_aggregation import aggregate_history_prices, aggregate_product_info @@ -13,12 +14,12 @@ class TestMarketDataAggregator: prod.symbol=symbol prod.price=price prod.volume_24h=volume - prod.quote_currency=currency + prod.currency=currency return prod - def __price(self, timestamp_ms: int, high: float, low: float, open: float, close: float, volume: float) -> Price: + def __price(self, timestamp_s: int, high: float, low: float, open: float, close: float, volume: float) -> Price: price = Price() - price.timestamp_ms = timestamp_ms + price.set_timestamp(timestamp_s=timestamp_s) price.high = high price.low = low price.open = open @@ -41,9 +42,9 @@ class TestMarketDataAggregator: assert info.symbol == "BTC" avg_weighted_price = (50000.0 * 1000.0 + 50100.0 * 1100.0 + 49900.0 * 900.0) / (1000.0 + 1100.0 + 900.0) - assert info.price == pytest.approx(avg_weighted_price, rel=1e-3) - assert info.volume_24h == pytest.approx(1000.0, rel=1e-3) - assert info.quote_currency == "USD" + assert info.price == pytest.approx(avg_weighted_price, rel=1e-3) # type: ignore + assert info.volume_24h == pytest.approx(1000.0, rel=1e-3) # type: ignore + assert info.currency == "USD" def test_aggregate_product_info_multiple_symbols(self): products = { @@ -65,18 +66,18 @@ class TestMarketDataAggregator: assert btc_info is not None avg_weighted_price_btc = (50000.0 * 1000.0 + 50100.0 * 1100.0) / (1000.0 + 1100.0) - assert btc_info.price == pytest.approx(avg_weighted_price_btc, rel=1e-3) - assert btc_info.volume_24h == pytest.approx(1050.0, rel=1e-3) - assert btc_info.quote_currency == "USD" + assert btc_info.price == pytest.approx(avg_weighted_price_btc, rel=1e-3) # type: ignore + assert btc_info.volume_24h == pytest.approx(1050.0, rel=1e-3) # type: ignore + assert btc_info.currency == "USD" assert eth_info is not None avg_weighted_price_eth = (4000.0 * 2000.0 + 4050.0 * 2100.0) / (2000.0 + 2100.0) - assert eth_info.price == pytest.approx(avg_weighted_price_eth, rel=1e-3) - assert eth_info.volume_24h == pytest.approx(2050.0, rel=1e-3) - assert eth_info.quote_currency == "USD" + assert eth_info.price == pytest.approx(avg_weighted_price_eth, rel=1e-3) # type: ignore + assert eth_info.volume_24h == pytest.approx(2050.0, rel=1e-3) # type: ignore + assert eth_info.currency == "USD" def test_aggregate_product_info_with_no_data(self): - products = { + products: dict[str, list[ProductInfo]] = { "Provider1": [], "Provider2": [], } @@ -84,7 +85,7 @@ class TestMarketDataAggregator: assert len(aggregated) == 0 def test_aggregate_product_info_with_partial_data(self): - products = { + products: dict[str, list[ProductInfo]] = { "Provider1": [self.__product("BTC", 50000.0, 1000.0, "USD")], "Provider2": [], } @@ -92,29 +93,38 @@ class TestMarketDataAggregator: assert len(aggregated) == 1 info = aggregated[0] assert info.symbol == "BTC" - assert info.price == pytest.approx(50000.0, rel=1e-3) - assert info.volume_24h == pytest.approx(1000.0, rel=1e-3) - assert info.quote_currency == "USD" + assert info.price == pytest.approx(50000.0, rel=1e-3) # type: ignore + assert info.volume_24h == pytest.approx(1000.0, rel=1e-3) # type: ignore + assert info.currency == "USD" def test_aggregate_history_prices(self): """Test aggregazione di prezzi storici usando aggregate_history_prices""" + timestamp_now = datetime.now() + timestamp_1h_ago = int(timestamp_now.replace(hour=timestamp_now.hour - 1).timestamp()) + timestamp_2h_ago = int(timestamp_now.replace(hour=timestamp_now.hour - 2).timestamp()) prices = { "Provider1": [ - self.__price(1685577600000, 50000.0, 49500.0, 49600.0, 49900.0, 150.0), - self.__price(1685581200000, 50200.0, 49800.0, 50000.0, 50100.0, 200.0), + self.__price(timestamp_1h_ago, 50000.0, 49500.0, 49600.0, 49900.0, 150.0), + self.__price(timestamp_2h_ago, 50200.0, 49800.0, 50000.0, 50100.0, 200.0), ], "Provider2": [ - self.__price(1685577600000, 50100.0, 49600.0, 49700.0, 50000.0, 180.0), - self.__price(1685581200000, 50300.0, 49900.0, 50100.0, 50200.0, 220.0), + self.__price(timestamp_1h_ago, 50100.0, 49600.0, 49700.0, 50000.0, 180.0), + self.__price(timestamp_2h_ago, 50300.0, 49900.0, 50100.0, 50200.0, 220.0), ], } + price = Price() + price.set_timestamp(timestamp_s=timestamp_1h_ago) + timestamp_1h_ago = price.timestamp + price.set_timestamp(timestamp_s=timestamp_2h_ago) + timestamp_2h_ago = price.timestamp + aggregated = aggregate_history_prices(prices) assert len(aggregated) == 2 - assert aggregated[0].timestamp_ms == 1685577600000 - assert aggregated[0].high == pytest.approx(50050.0, rel=1e-3) - assert aggregated[0].low == pytest.approx(49550.0, rel=1e-3) - assert aggregated[1].timestamp_ms == 1685581200000 - assert aggregated[1].high == pytest.approx(50250.0, rel=1e-3) - assert aggregated[1].low == pytest.approx(49850.0, rel=1e-3) + assert aggregated[0].timestamp == timestamp_1h_ago + assert aggregated[0].high == pytest.approx(50050.0, rel=1e-3) # type: ignore + assert aggregated[0].low == pytest.approx(49550.0, rel=1e-3) # type: ignore + assert aggregated[1].timestamp == timestamp_2h_ago + assert aggregated[1].high == pytest.approx(50250.0, rel=1e-3) # type: ignore + assert aggregated[1].low == pytest.approx(49850.0, rel=1e-3) # type: ignore diff --git a/tests/utils/test_wrapper_handler.py b/tests/utils/test_wrapper_handler.py index 996f632..c6094a1 100644 --- a/tests/utils/test_wrapper_handler.py +++ b/tests/utils/test_wrapper_handler.py @@ -37,7 +37,7 @@ class TestWrapperHandler: def test_init_failing_with_instances(self): with pytest.raises(AssertionError) as exc_info: - WrapperHandler.build_wrappers([MockWrapper(), MockWrapper2()]) + WrapperHandler.build_wrappers([MockWrapper(), MockWrapper2()]) # type: ignore assert exc_info.type == AssertionError def test_init_not_failing(self): @@ -49,104 +49,98 @@ class TestWrapperHandler: assert len(handler.wrappers) == 2 def test_all_wrappers_fail(self): - wrappers = [FailingWrapper, FailingWrapper] - handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) + wrappers: list[type[MockWrapper]] = [FailingWrapper, FailingWrapper] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) with pytest.raises(Exception) as exc_info: handler.try_call(lambda w: w.do_something()) assert "All wrappers failed" in str(exc_info.value) def test_success_on_first_try(self): - wrappers = [MockWrapper, FailingWrapper] - handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) + wrappers: list[type[MockWrapper]] = [MockWrapper, FailingWrapper] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) result = handler.try_call(lambda w: w.do_something()) assert result == "Success" assert handler.index == 0 # Should still be on the first wrapper - assert handler.retry_count == 0 def test_eventual_success(self): - wrappers = [FailingWrapper, MockWrapper] - handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) + wrappers: list[type[MockWrapper]] = [FailingWrapper, MockWrapper] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) result = handler.try_call(lambda w: w.do_something()) assert result == "Success" assert handler.index == 1 # Should have switched to the second wrapper - assert handler.retry_count == 0 def test_partial_failures(self): - wrappers = [FailingWrapper, MockWrapper, FailingWrapper] - handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) + wrappers: list[type[MockWrapper]] = [FailingWrapper, MockWrapper, FailingWrapper] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) result = handler.try_call(lambda w: w.do_something()) assert result == "Success" assert handler.index == 1 # Should have switched to the second wrapper - assert handler.retry_count == 0 # Next call should still succeed on the second wrapper result = handler.try_call(lambda w: w.do_something()) assert result == "Success" assert handler.index == 1 # Should still be on the second wrapper - assert handler.retry_count == 0 handler.index = 2 # Manually switch to the third wrapper result = handler.try_call(lambda w: w.do_something()) assert result == "Success" assert handler.index == 1 # Should return to the second wrapper after failure - assert handler.retry_count == 0 def test_try_call_all_success(self): - wrappers = [MockWrapper, MockWrapper2] - handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) + wrappers: list[type[MockWrapper]] = [MockWrapper, MockWrapper2] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) results = handler.try_call_all(lambda w: w.do_something()) - assert results == {MockWrapper: "Success", MockWrapper2: "Success 2"} + assert results == {MockWrapper.__name__: "Success", MockWrapper2.__name__: "Success 2"} def test_try_call_all_partial_failures(self): # Only the second wrapper should succeed - wrappers = [FailingWrapper, MockWrapper, FailingWrapper] - handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) + wrappers: list[type[MockWrapper]] = [FailingWrapper, MockWrapper, FailingWrapper] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) results = handler.try_call_all(lambda w: w.do_something()) - assert results == {MockWrapper: "Success"} + assert results == {MockWrapper.__name__: "Success"} # Only the second and fourth wrappers should succeed - wrappers = [FailingWrapper, MockWrapper, FailingWrapper, MockWrapper2] - handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) + wrappers: list[type[MockWrapper]] = [FailingWrapper, MockWrapper, FailingWrapper, MockWrapper2] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) results = handler.try_call_all(lambda w: w.do_something()) - assert results == {MockWrapper: "Success", MockWrapper2: "Success 2"} + assert results == {MockWrapper.__name__: "Success", MockWrapper2.__name__: "Success 2"} def test_try_call_all_all_fail(self): # Test when all wrappers fail - handler_all_fail: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers([FailingWrapper, FailingWrapper], try_per_wrapper=1, retry_delay=0) + handler_all_fail = WrapperHandler.build_wrappers([FailingWrapper, FailingWrapper], try_per_wrapper=1, retry_delay=0) with pytest.raises(Exception) as exc_info: handler_all_fail.try_call_all(lambda w: w.do_something()) assert "All wrappers failed" in str(exc_info.value) def test_wrappers_with_parameters(self): - wrappers = [FailingWrapperWithParameters, MockWrapperWithParameters] - handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) + wrappers: list[type[MockWrapperWithParameters]] = [FailingWrapperWithParameters, MockWrapperWithParameters] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0) result = handler.try_call(lambda w: w.do_something("test", 42)) assert result == "Success test and 42" assert handler.index == 1 # Should have switched to the second wrapper - assert handler.retry_count == 0 def test_wrappers_with_parameters_all_fail(self): - wrappers = [FailingWrapperWithParameters, FailingWrapperWithParameters] - handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) + wrappers: list[type[MockWrapperWithParameters]] = [FailingWrapperWithParameters, FailingWrapperWithParameters] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) with pytest.raises(Exception) as exc_info: handler.try_call(lambda w: w.do_something("test", 42)) assert "All wrappers failed" in str(exc_info.value) def test_try_call_all_with_parameters(self): - wrappers = [FailingWrapperWithParameters, MockWrapperWithParameters] - handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) + wrappers: list[type[MockWrapperWithParameters]] = [FailingWrapperWithParameters, MockWrapperWithParameters] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) results = handler.try_call_all(lambda w: w.do_something("param", 99)) - assert results == {MockWrapperWithParameters: "Success param and 99"} + assert results == {MockWrapperWithParameters.__name__: "Success param and 99"} def test_try_call_all_with_parameters_all_fail(self): - wrappers = [FailingWrapperWithParameters, FailingWrapperWithParameters] - handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) + wrappers: list[type[MockWrapperWithParameters]] = [FailingWrapperWithParameters, FailingWrapperWithParameters] + handler = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0) with pytest.raises(Exception) as exc_info: handler.try_call_all(lambda w: w.do_something("param", 99)) assert "All wrappers failed" in str(exc_info.value)