Pydantic
- use Pydantic for input & output for models - update ToolAgent to utilize new model definitions - improve test cases for consistency
This commit is contained in:
@@ -1,19 +1,29 @@
|
||||
import json
|
||||
import pytest
|
||||
from app.agents import predictor
|
||||
from app.models import Models
|
||||
from app.agents.predictor import PREDICTOR_INSTRUCTIONS, PredictorInput, PredictorOutput, PredictorStyle
|
||||
from app.markets.base import ProductInfo
|
||||
from app.models import AppModels
|
||||
|
||||
def unified_checks(model: Models, input):
|
||||
llm = model.get_agent(predictor.instructions())
|
||||
def unified_checks(model: AppModels, input):
|
||||
llm = model.get_agent(PREDICTOR_INSTRUCTIONS, output=PredictorOutput)
|
||||
result = llm.run(input)
|
||||
content = result.content
|
||||
|
||||
print(result.content)
|
||||
potential_json = Models.extract_json_str_from_response(result.content)
|
||||
content = json.loads(potential_json) # Verifica che l'output sia un JSON valido
|
||||
|
||||
assert content['strategia'] is not None
|
||||
assert isinstance(content['portafoglio'], list)
|
||||
assert abs(sum(item['percentuale'] for item in content['portafoglio']) - 100) < 0.01 # La somma deve essere esattamente 100
|
||||
assert isinstance(content, PredictorOutput)
|
||||
assert content.strategy not in (None, "", "null")
|
||||
assert isinstance(content.strategy, str)
|
||||
assert isinstance(content.portfolio, list)
|
||||
assert len(content.portfolio) > 0
|
||||
for item in content.portfolio:
|
||||
assert item.asset not in (None, "", "null")
|
||||
assert isinstance(item.asset, str)
|
||||
assert item.percentage > 0
|
||||
assert item.percentage <= 100
|
||||
assert isinstance(item.percentage, (int, float))
|
||||
assert item.motivation not in (None, "", "null")
|
||||
assert isinstance(item.motivation, str)
|
||||
# La somma delle percentuali deve essere esattamente 100
|
||||
total_percentage = sum(item.percentage for item in content.portfolio)
|
||||
assert abs(total_percentage - 100) < 0.01 # Permette una piccola tolleranza per errori di arrotondamento
|
||||
|
||||
class TestPredictor:
|
||||
|
||||
@@ -21,23 +31,19 @@ class TestPredictor:
|
||||
def inputs(self):
|
||||
data = []
|
||||
for symbol, price in [("BTC", 60000.00), ("ETH", 3500.00), ("SOL", 150.00)]:
|
||||
product_info = predictor.ProductInfo()
|
||||
product_info = ProductInfo()
|
||||
product_info.symbol = symbol
|
||||
product_info.price = price
|
||||
data.append(product_info)
|
||||
|
||||
return predictor.prepare_inputs(
|
||||
data=data,
|
||||
style=predictor.PredictorStyle.AGGRESSIVE,
|
||||
sentiment="positivo"
|
||||
)
|
||||
return PredictorInput(data=data, style=PredictorStyle.AGGRESSIVE, sentiment="positivo")
|
||||
|
||||
def test_gemini_model_output(self, inputs):
|
||||
unified_checks(Models.GEMINI, inputs)
|
||||
unified_checks(AppModels.GEMINI, inputs)
|
||||
|
||||
def test_ollama_qwen_model_output(self, inputs):
|
||||
unified_checks(AppModels.OLLAMA_QWEN, inputs)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_ollama_gpt_oss_model_output(self, inputs):
|
||||
unified_checks(Models.OLLAMA_GPT, inputs)
|
||||
|
||||
def test_ollama_qwen_model_output(self, inputs):
|
||||
unified_checks(Models.OLLAMA_QWEN, inputs)
|
||||
unified_checks(AppModels.OLLAMA_GPT, inputs)
|
||||
|
||||
Reference in New Issue
Block a user