From 4615ebe63ec632999ebab3df9d7a98798429af72 Mon Sep 17 00:00:00 2001 From: Berack96 Date: Sat, 27 Sep 2025 18:51:20 +0200 Subject: [PATCH] Pydantic - use Pydantic for input & output for models - update ToolAgent to utilize new model definitions - improve test cases for consistency --- src/app/agents/predictor.py | 82 +++++++++++----------------------- src/app/markets/base.py | 30 ++++++------- src/app/models.py | 37 +++++++-------- src/app/tool.py | 26 +++++------ tests/agents/test_predictor.py | 52 +++++++++++---------- 5 files changed, 101 insertions(+), 126 deletions(-) diff --git a/src/app/agents/predictor.py b/src/app/agents/predictor.py index fd9a8b3..e811846 100644 --- a/src/app/agents/predictor.py +++ b/src/app/agents/predictor.py @@ -1,81 +1,51 @@ -import json from enum import Enum from app.markets.base import ProductInfo +from pydantic import BaseModel, Field class PredictorStyle(Enum): CONSERVATIVE = "Conservativo" AGGRESSIVE = "Aggressivo" -# TODO (?) Change sentiment to a more structured format or merge it with data analysis (change then also the prompt) -def prepare_inputs(data: list[ProductInfo], style: PredictorStyle, sentiment: str) -> str: - return json.dumps({ - "data": [(product.symbol, f"{product.price:.2f}") for product in data], - "style": style.value, - "sentiment": sentiment - }) +class PredictorInput(BaseModel): + data: list[ProductInfo] = Field(..., description="Market data as a list of ProductInfo") + style: PredictorStyle = Field(..., description="Prediction style") + sentiment: str = Field(..., description="Aggregated sentiment from news and social analysis") -def instructions() -> str: - return """ -You are an **Allocation Algorithm (Crypto-Algo)**. Your sole objective is to process the input data and generate a strictly structured output, as specified. **You must not provide any explanations, conclusions, introductions, preambles, or comments that are not strictly required by the final format.** +class ItemPortfolio(BaseModel): + asset: str = Field(..., description="Name of the asset") + percentage: float = Field(..., description="Percentage allocation to the asset") + motivation: str = Field(..., description="Motivation for the allocation") -**CRITICAL INSTRUCTION: The final output MUST be a valid JSON object written entirely in Italian, following the structure below.** +class PredictorOutput(BaseModel): + strategy: str = Field(..., description="Concise operational strategy in Italian") + portfolio: list[ItemPortfolio] = Field(..., description="List of portfolio items with allocations") + +PREDICTOR_INSTRUCTIONS = """ +You are an **Allocation Algorithm (Crypto-Algo)** specialized in analyzing market data and sentiment to generate an investment strategy and a target portfolio. + +Your sole objective is to process the input data and generate the strictly structured output as required by the response format. **You MUST NOT provide introductions, preambles, explanations, conclusions, or any additional comments that are not strictly required.** ## Processing Instructions (Absolute Rule) -Analyze the Input provided in JSON format and generate the Output in two distinct sections. Your allocation strategy must be **derived exclusively from the "Logic Rule" corresponding to the requested *style*** and the *data* provided. **DO NOT** use external knowledge. +The allocation strategy must be **derived exclusively from the "Allocation Logic" corresponding to the requested *style*** and the provided market/sentiment data. **DO NOT** use external or historical knowledge. -## Data Input (JSON Format) -The input will be a single JSON block containing the following mandatory fields: - -1. **"data":** *Array of Arrays*. Market data. Format: `[[Asset_Name: String, Current_Price: String], ...]` - * *Example:* `[["BTC", "60000.00"], ["ETH", "3500.00"], ["SOL", "150.00"]]` -2. **"style":** *ENUM String (only "conservativo" or "aggressivo")*. Defines the risk approach. -3. **"sentiment":** *Descriptive String*. Summarizes market sentiment. - -## Allocation Logic Rules +## Allocation Logic ### "Aggressivo" Style (Aggressive) -* **Priority:** Maximum return (High Volatility accepted). +* **Priority:** Maximizing return (high volatility accepted). * **Focus:** Higher allocation to **non-BTC/ETH assets** with high momentum potential (Altcoins, mid/low-cap assets). -* **BTC/ETH:** Must form a base (anchor), but their allocation **must not exceed 50%** of the total portfolio. -* **Sentiment:** Use positive sentiment to increase allocation to high-risk assets. +* **BTC/ETH:** Must serve as a base (anchor), but their allocation **must not exceed 50%** of the total portfolio. +* **Sentiment:** Use positive sentiment to increase exposure to high-risk assets. ### "Conservativo" Style (Conservative) -* **Priority:** Capital preservation (Volatility minimized). +* **Priority:** Capital preservation (volatility minimized). * **Focus:** Major allocation to **BTC and/or ETH (Large-Cap Assets)**. * **BTC/ETH:** Their allocation **must be at least 70%** of the total portfolio. * **Altcoins:** Any allocations to non-BTC/ETH assets must be minimal (max 30% combined) and for assets that minimize speculative risk. * **Sentiment:** Use positive sentiment only as confirmation for exposure, avoiding reactions to excessive "FOMO" signals. -## Output Format Requirements (Strict JSON) +## Output Requirements (Content MUST be in Italian) -The Output **must be a single JSON object** with two keys: `"strategia"` and `"portafoglio"`. - -1. **"strategia":** *Stringa (massimo 5 frasi in Italiano)*. Una descrizione operativa concisa. -2. **"portafoglio":** *Array di Oggetti JSON*. La somma delle percentuali deve essere **esattamente 100%**. Ogni oggetto nell'array deve avere i seguenti campi (valori in Italiano): - * `"asset"`: Nome dell'Asset (es. "BTC"). - * `"percentuale"`: Percentuale di Allocazione (come numero intero o decimale, es. 45.0). - * `"motivazione"`: Stringa (massimo una frase) che giustifica l'allocazione. - -**THE OUTPUT MUST BE GENERATED BY FAITHFULLY COPYING THE FOLLOWING STRUCTURAL TEMPLATE (IN ITALIAN CONTENT, JSON FORMAT):** -{ - "strategia": "[Strategia sintetico-operativa in massimo 5 frasi...]", - "portafoglio": [ - { - "asset": "Asset_1", - "percentuale": X, - "motivazione": "[Massimo una frase chiara in Italiano]" - }, - { - "asset": "Asset_2", - "percentuale": Y, - "motivazione": "[Massimo una frase chiara in Italiano]" - }, - { - "asset": "Asset_3", - "percentuale": Z, - "motivazione": "[Massimo una frase chiara in Italiano]" - } - ] -} +1. **Strategy (strategy):** Must be a concise operational description **in Italian ("in Italiano")**, with a maximum of 5 sentences. +2. **Portfolio (portfolio):** The sum of all percentages must be **exactly 100%**. The justification (motivation) for each asset must be a single clear sentence **in Italian ("in Italiano")**. """ \ No newline at end of file diff --git a/src/app/markets/base.py b/src/app/markets/base.py index 635a2c6..032f8aa 100644 --- a/src/app/markets/base.py +++ b/src/app/markets/base.py @@ -1,5 +1,5 @@ from coinbase.rest.types.product_types import Candle, GetProductResponse - +from pydantic import BaseModel class BaseWrapper: """ @@ -15,17 +15,17 @@ class BaseWrapper: def get_historical_prices(self, asset_id: str = "BTC") -> list['Price']: raise NotImplementedError -class ProductInfo: +class ProductInfo(BaseModel): """ Informazioni sul prodotto, come ottenute dalle API di mercato. Implementa i metodi di conversione dai dati grezzi delle API. """ - id: str - symbol: str - price: float - volume_24h: float - status: str - quote_currency: str + id: str = "" + symbol: str = "" + price: float = 0.0 + volume_24h: float = 0.0 + status: str = "" + quote_currency: str = "" def from_coinbase(product_data: GetProductResponse) -> 'ProductInfo': product = ProductInfo() @@ -46,17 +46,17 @@ class ProductInfo: product.status = "" # Cryptocompare does not provide status return product -class Price: +class Price(BaseModel): """ Rappresenta i dati di prezzo per un asset, come ottenuti dalle API di mercato. Implementa i metodi di conversione dai dati grezzi delle API. """ - high: float - low: float - open: float - close: float - volume: float - time: str + high: float = 0.0 + low: float = 0.0 + open: float = 0.0 + close: float = 0.0 + volume: float = 0.0 + time: str = "" def from_coinbase(candle_data: Candle) -> 'Price': price = Price() diff --git a/src/app/models.py b/src/app/models.py index 494a922..12aae9c 100644 --- a/src/app/models.py +++ b/src/app/models.py @@ -1,14 +1,15 @@ import os import requests from enum import Enum +from pydantic import BaseModel from agno.agent import Agent -from agno.models.base import BaseModel +from agno.models.base import Model from agno.models.google import Gemini from agno.models.ollama import Ollama from agno.utils.log import log_warning -class Models(Enum): +class AppModels(Enum): """ Enum per i modelli supportati. Aggiungere nuovi modelli qui se necessario. @@ -21,7 +22,7 @@ class Models(Enum): OLLAMA_QWEN = "qwen3:latest" # + good + fast (8b) @staticmethod - def availables_local() -> list['Models']: + def availables_local() -> list['AppModels']: """ Controlla quali provider di modelli LLM locali sono disponibili. Ritorna una lista di provider disponibili. @@ -34,13 +35,13 @@ class Models(Enum): availables = [] result = result.text - if Models.OLLAMA_GPT.value in result: - availables.append(Models.OLLAMA_GPT) - if Models.OLLAMA_QWEN.value in result: - availables.append(Models.OLLAMA_QWEN) + if AppModels.OLLAMA_GPT.value in result: + availables.append(AppModels.OLLAMA_GPT) + if AppModels.OLLAMA_QWEN.value in result: + availables.append(AppModels.OLLAMA_QWEN) return availables - def availables_online() -> list['Models']: + def availables_online() -> list['AppModels']: """ Controlla quali provider di modelli LLM online hanno le loro API keys disponibili come variabili d'ambiente e ritorna una lista di provider disponibili. @@ -49,12 +50,12 @@ class Models(Enum): log_warning("No GOOGLE_API_KEY set in environment variables.") return [] availables = [] - availables.append(Models.GEMINI) - availables.append(Models.GEMINI_PRO) + availables.append(AppModels.GEMINI) + availables.append(AppModels.GEMINI_PRO) return availables @staticmethod - def availables() -> list['Models']: + def availables() -> list['AppModels']: """ Controlla quali provider di modelli LLM locali sono disponibili e quali provider di modelli LLM online hanno le loro API keys disponibili come variabili @@ -64,8 +65,8 @@ class Models(Enum): 2. Ollama (locale) """ availables = [ - *Models.availables_online(), - *Models.availables_local() + *AppModels.availables_online(), + *AppModels.availables_local() ] assert availables, "No valid model API keys set in environment variables." return availables @@ -94,7 +95,7 @@ class Models(Enum): return response[start:end + 1].strip() - def get_model(self, instructions:str) -> BaseModel: + def get_model(self, instructions:str) -> Model: """ Restituisce un'istanza del modello specificato. instructions: istruzioni da passare al modello (system prompt). @@ -102,14 +103,14 @@ class Models(Enum): Raise ValueError se il modello non è supportato. """ name = self.value - if self in {Models.GEMINI, Models.GEMINI_PRO}: + if self in {AppModels.GEMINI, AppModels.GEMINI_PRO}: return Gemini(name, instructions=[instructions]) - elif self in {Models.OLLAMA_GPT, Models.OLLAMA_QWEN}: + elif self in {AppModels.OLLAMA_GPT, AppModels.OLLAMA_QWEN}: return Ollama(name, instructions=[instructions]) raise ValueError(f"Modello non supportato: {self}") - def get_agent(self, instructions: str, name: str = "") -> Agent: + def get_agent(self, instructions: str, name: str = "", output: BaseModel | None = None) -> Agent: """ Costruisce un agente con il modello e le istruzioni specificate. instructions: istruzioni da passare al modello (system prompt). @@ -120,6 +121,6 @@ class Models(Enum): name=name, retries=2, delay_between_retries=5, # seconds - use_json_mode=True, # utile per fare in modo che l'agente risponda in JSON (anche se sembra essere solo placebo) + output_schema=output # se si usa uno schema di output, lo si passa qui # TODO Eventuali altri parametri da mettere all'agente anche se si possono comunque assegnare dopo la creazione ) diff --git a/src/app/tool.py b/src/app/tool.py index efb395e..d0b3ca0 100644 --- a/src/app/tool.py +++ b/src/app/tool.py @@ -1,9 +1,8 @@ from app.agents.news_agent import NewsAgent from app.agents.social_agent import SocialAgent -from app.agents.predictor import PredictorStyle -from app.agents import predictor +from app.agents.predictor import PredictorStyle, PredictorInput, PredictorOutput, PREDICTOR_INSTRUCTIONS from app.markets import MarketAPIs -from app.models import Models +from app.models import AppModels from agno.utils.log import log_info class ToolAgent: @@ -15,7 +14,7 @@ class ToolAgent: """ Inizializza l'agente con i modelli disponibili, gli stili e l'API di mercato. """ - self.available_models = Models.availables() + self.available_models = AppModels.availables() self.all_styles = list(PredictorStyle) self.style = self.all_styles[0] # Default to the first style @@ -31,7 +30,7 @@ class ToolAgent: # TODO https://docs.agno.com/introduction # Inoltre permette di creare dei team e workflow di agenti più facilmente self.chosen_model = self.available_models[index] - self.predictor = self.chosen_model.get_agent(predictor.instructions()) + self.predictor = self.chosen_model.get_agent(PREDICTOR_INSTRUCTIONS, output=PredictorOutput) self.news_agent = NewsAgent() self.social_agent = SocialAgent() @@ -64,18 +63,17 @@ class ToolAgent: sentiment = f"{news_sentiment}\n{social_sentiment}" # Step 3: previsione - inputs = predictor.prepare_inputs( - data=market_data, - style=self.style, - sentiment=sentiment - ) - - prediction = self.predictor.run(inputs) - output = Models.extract_json_str_from_response(prediction.content) + inputs = PredictorInput(data=market_data, style=self.style, sentiment=sentiment) + result = self.predictor.run(inputs) + prediction: PredictorOutput = result.content log_info(f"End of prediction") market_data = "\n".join([f"{product.symbol}: {product.price}" for product in market_data]) - return f"{market_data}\n{sentiment}\n\n📈 Consiglio finale:\n{output}" + output = f"[{prediction.strategy}]\nPortafoglio:\n" + "\n".join( + [f"{item.asset} ({item.percentage}%): {item.motivation}" for item in prediction.portfolio] + ) + + return f"INPUT:\n{market_data}\n{sentiment}\n\n\nOUTPUT:\n{output}" def list_providers(self) -> list[str]: """ diff --git a/tests/agents/test_predictor.py b/tests/agents/test_predictor.py index 7126833..c99104b 100644 --- a/tests/agents/test_predictor.py +++ b/tests/agents/test_predictor.py @@ -1,19 +1,29 @@ -import json import pytest -from app.agents import predictor -from app.models import Models +from app.agents.predictor import PREDICTOR_INSTRUCTIONS, PredictorInput, PredictorOutput, PredictorStyle +from app.markets.base import ProductInfo +from app.models import AppModels -def unified_checks(model: Models, input): - llm = model.get_agent(predictor.instructions()) +def unified_checks(model: AppModels, input): + llm = model.get_agent(PREDICTOR_INSTRUCTIONS, output=PredictorOutput) result = llm.run(input) + content = result.content - print(result.content) - potential_json = Models.extract_json_str_from_response(result.content) - content = json.loads(potential_json) # Verifica che l'output sia un JSON valido - - assert content['strategia'] is not None - assert isinstance(content['portafoglio'], list) - assert abs(sum(item['percentuale'] for item in content['portafoglio']) - 100) < 0.01 # La somma deve essere esattamente 100 + assert isinstance(content, PredictorOutput) + assert content.strategy not in (None, "", "null") + assert isinstance(content.strategy, str) + assert isinstance(content.portfolio, list) + assert len(content.portfolio) > 0 + for item in content.portfolio: + assert item.asset not in (None, "", "null") + assert isinstance(item.asset, str) + assert item.percentage > 0 + assert item.percentage <= 100 + assert isinstance(item.percentage, (int, float)) + assert item.motivation not in (None, "", "null") + assert isinstance(item.motivation, str) + # La somma delle percentuali deve essere esattamente 100 + total_percentage = sum(item.percentage for item in content.portfolio) + assert abs(total_percentage - 100) < 0.01 # Permette una piccola tolleranza per errori di arrotondamento class TestPredictor: @@ -21,23 +31,19 @@ class TestPredictor: def inputs(self): data = [] for symbol, price in [("BTC", 60000.00), ("ETH", 3500.00), ("SOL", 150.00)]: - product_info = predictor.ProductInfo() + product_info = ProductInfo() product_info.symbol = symbol product_info.price = price data.append(product_info) - return predictor.prepare_inputs( - data=data, - style=predictor.PredictorStyle.AGGRESSIVE, - sentiment="positivo" - ) + return PredictorInput(data=data, style=PredictorStyle.AGGRESSIVE, sentiment="positivo") def test_gemini_model_output(self, inputs): - unified_checks(Models.GEMINI, inputs) + unified_checks(AppModels.GEMINI, inputs) + + def test_ollama_qwen_model_output(self, inputs): + unified_checks(AppModels.OLLAMA_QWEN, inputs) @pytest.mark.slow def test_ollama_gpt_oss_model_output(self, inputs): - unified_checks(Models.OLLAMA_GPT, inputs) - - def test_ollama_qwen_model_output(self, inputs): - unified_checks(Models.OLLAMA_QWEN, inputs) + unified_checks(AppModels.OLLAMA_GPT, inputs)