Compare commits
13 Commits
fix-config
...
62-aggrega
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ed2342dbd | ||
| 15b279faa1 | |||
| 30ddb76df7 | |||
| 3327bf8127 | |||
| 192adec7d0 | |||
| 6e2203f984 | |||
| c66332f240 | |||
| 968c137a5a | |||
|
|
83363f1b75 | ||
|
|
14b20ed07d | ||
|
|
c07938618a | ||
|
|
512bc4568e | ||
| 55858a7458 |
@@ -13,43 +13,56 @@ class ProductInfo(BaseModel):
|
||||
price: float = 0.0
|
||||
volume_24h: float = 0.0
|
||||
currency: str = ""
|
||||
provider: str = ""
|
||||
|
||||
@staticmethod
|
||||
def aggregate(products: dict[str, list['ProductInfo']]) -> list['ProductInfo']:
|
||||
def aggregate(products: dict[str, list['ProductInfo']], filter_currency: str="USD") -> list['ProductInfo']:
|
||||
"""
|
||||
Aggregates a list of ProductInfo by symbol.
|
||||
Args:
|
||||
products (dict[str, list[ProductInfo]]): Map provider -> list of ProductInfo
|
||||
filter_currency (str): If set, only products with this currency are considered. Defaults to "USD".
|
||||
Returns:
|
||||
list[ProductInfo]: List of ProductInfo aggregated by symbol
|
||||
"""
|
||||
|
||||
# Costruzione mappa symbol -> lista di ProductInfo
|
||||
symbols_infos: dict[str, list[ProductInfo]] = {}
|
||||
for _, product_list in products.items():
|
||||
# Costruzione mappa id -> lista di ProductInfo + lista di provider
|
||||
id_infos: dict[str, tuple[list[ProductInfo], list[str]]] = {}
|
||||
for provider, product_list in products.items():
|
||||
for product in product_list:
|
||||
symbols_infos.setdefault(product.symbol, []).append(product)
|
||||
if filter_currency and product.currency != filter_currency:
|
||||
continue
|
||||
id_value = product.id.upper().replace("-", "") # Normalizzazione id per compatibilità (es. BTC-USD -> btcusd)
|
||||
product_list, provider_list = id_infos.setdefault(id_value, ([], []) )
|
||||
product_list.append(product)
|
||||
provider_list.append(provider)
|
||||
|
||||
# Aggregazione per ogni symbol
|
||||
# Aggregazione per ogni id
|
||||
aggregated_products: list[ProductInfo] = []
|
||||
for symbol, product_list in symbols_infos.items():
|
||||
for id_value, (product_list, provider_list) in id_infos.items():
|
||||
product = ProductInfo()
|
||||
|
||||
product.id = f"{symbol}_AGGREGATED"
|
||||
product.symbol = symbol
|
||||
product.id = f"{id_value}_AGGREGATED"
|
||||
product.symbol = next(p.symbol for p in product_list if p.symbol)
|
||||
product.currency = next(p.currency for p in product_list if p.currency)
|
||||
|
||||
volume_sum = sum(p.volume_24h for p in product_list)
|
||||
product.volume_24h = volume_sum / len(product_list) if product_list else 0.0
|
||||
|
||||
prices = sum(p.price * p.volume_24h for p in product_list)
|
||||
product.price = (prices / volume_sum) if volume_sum > 0 else 0.0
|
||||
if volume_sum > 0:
|
||||
# Calcolo del prezzo pesato per volume (VWAP - Volume Weighted Average Price)
|
||||
prices_weighted = sum(p.price * p.volume_24h for p in product_list if p.volume_24h > 0)
|
||||
product.price = prices_weighted / volume_sum
|
||||
else:
|
||||
# Se non c'è volume, facciamo una media semplice dei prezzi
|
||||
valid_prices = [p.price for p in product_list if p.price > 0]
|
||||
product.price = sum(valid_prices) / len(valid_prices) if valid_prices else 0.0
|
||||
|
||||
product.provider = ",".join(provider_list)
|
||||
aggregated_products.append(product)
|
||||
return aggregated_products
|
||||
|
||||
|
||||
|
||||
class Price(BaseModel):
|
||||
"""
|
||||
Represents price data for an asset as obtained from market APIs.
|
||||
|
||||
@@ -37,6 +37,7 @@ class MarketAPIsTool(MarketWrapper, Toolkit):
|
||||
self.get_product,
|
||||
self.get_products,
|
||||
self.get_historical_prices,
|
||||
self.get_product_aggregated,
|
||||
self.get_products_aggregated,
|
||||
self.get_historical_prices_aggregated,
|
||||
],
|
||||
@@ -94,6 +95,27 @@ class MarketAPIsTool(MarketWrapper, Toolkit):
|
||||
"""
|
||||
return self.handler.try_call(lambda w: w.get_historical_prices(asset_id, limit))
|
||||
|
||||
@friendly_action("🧩 Aggrego le informazioni da più fonti...")
|
||||
def get_product_aggregated(self, asset_id: str) -> ProductInfo:
|
||||
"""
|
||||
Gets product information for a *single* asset from *all available providers* and *aggregates* the results.
|
||||
|
||||
This method queries all configured sources (Binance, YFinance, Coinbase, CryptoCompare)
|
||||
and combines the data using volume-weighted average price (VWAP) to provide
|
||||
the most accurate and comprehensive price data.
|
||||
|
||||
Args:
|
||||
asset_id (str): The asset ID to retrieve information for (e.g., "BTC", "ETH").
|
||||
|
||||
Returns:
|
||||
ProductInfo: A single ProductInfo object with aggregated data from all providers.
|
||||
The 'provider' field will list all sources used (e.g., "Binance, YFinance, Coinbase").
|
||||
|
||||
Raises:
|
||||
Exception: If all providers fail to return results.
|
||||
"""
|
||||
return self.get_products_aggregated([asset_id])[0]
|
||||
|
||||
@friendly_action("🧩 Aggrego le informazioni da più fonti...")
|
||||
def get_products_aggregated(self, asset_ids: list[str]) -> list[ProductInfo]:
|
||||
"""
|
||||
|
||||
@@ -16,7 +16,7 @@ BASE_URL = "https://finance.yahoo.com/markets/crypto/all/"
|
||||
|
||||
class CryptoSymbolsTools(Toolkit):
|
||||
"""
|
||||
Classe per ottenere i simboli delle criptovalute tramite Yahoo Finance.
|
||||
Class for obtaining cryptocurrency symbols via Yahoo Finance.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_file: str = 'resources/cryptos.csv'):
|
||||
@@ -34,29 +34,36 @@ class CryptoSymbolsTools(Toolkit):
|
||||
|
||||
def get_all_symbols(self) -> list[str]:
|
||||
"""
|
||||
Restituisce tutti i simboli delle criptovalute.
|
||||
Returns a complete list of all available cryptocurrency symbols (tickers).
|
||||
The list could be very long, prefer using 'get_symbols_by_name' for specific searches.
|
||||
|
||||
Returns:
|
||||
list[str]: Lista di tutti i simboli delle criptovalute.
|
||||
list[str]: A comprehensive list of all supported crypto symbols (e.g., "BTC-USD", "ETH-USD").
|
||||
"""
|
||||
return self.final_table['Symbol'].tolist() if not self.final_table.empty else []
|
||||
|
||||
def get_symbols_by_name(self, query: str) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Cerca i simboli che contengono la query.
|
||||
Searches the cryptocurrency database for assets matching a name or symbol.
|
||||
Use this to find the exact, correct symbol for a cryptocurrency name.
|
||||
Args:
|
||||
query (str): Query di ricerca.
|
||||
query (str): The name, partial name, or symbol to search for (e.g., "Bitcoin", "ETH").
|
||||
Returns:
|
||||
list[tuple[str, str]]: Lista di tuple (simbolo, nome) che contengono la query.
|
||||
list[tuple[str, str]]: A list of tuples, where each tuple contains
|
||||
the (symbol, full_name) of a matching asset.
|
||||
Returns an empty list if no matches are found.
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
positions = self.final_table['Name'].str.lower().str.contains(query_lower)
|
||||
return self.final_table[positions][['Symbol', 'Name']].apply(tuple, axis=1).tolist()
|
||||
positions = self.final_table['Name'].str.lower().str.contains(query_lower) | \
|
||||
self.final_table['Symbol'].str.lower().str.contains(query_lower)
|
||||
filtered_df = self.final_table[positions]
|
||||
return list(zip(filtered_df['Symbol'], filtered_df['Name']))
|
||||
|
||||
async def fetch_crypto_symbols(self, force_refresh: bool = False) -> None:
|
||||
"""
|
||||
Recupera tutti i simboli delle criptovalute da Yahoo Finance e li memorizza in cache.
|
||||
It retrieves all cryptocurrency symbols from Yahoo Finance and caches them.
|
||||
Args:
|
||||
force_refresh (bool): Se True, forza il recupero anche se i dati sono già in cache.
|
||||
force_refresh (bool): If True, it forces the retrieval even if the data are already in the cache.
|
||||
"""
|
||||
if not force_refresh and not self.final_table.empty:
|
||||
return
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
import threading
|
||||
import ollama
|
||||
import yaml
|
||||
import importlib
|
||||
import logging.config
|
||||
from typing import Any, ClassVar
|
||||
from pydantic import BaseModel
|
||||
@@ -68,34 +67,7 @@ class APIConfig(BaseModel):
|
||||
news_providers: list[str] = []
|
||||
social_providers: list[str] = []
|
||||
|
||||
def validate_providers(self) -> None:
|
||||
"""
|
||||
Validate that the configured providers are supported.
|
||||
Raises:
|
||||
ValueError if any provider is not supported.
|
||||
"""
|
||||
modules = [
|
||||
('app.api.markets', self.market_providers),
|
||||
('app.api.news', self.news_providers),
|
||||
('app.api.social', self.social_providers),
|
||||
]
|
||||
|
||||
for (module, config_providers) in modules:
|
||||
provider_type = module.split('.')[-1]
|
||||
mod = importlib.import_module(module)
|
||||
|
||||
supported_providers = set(getattr(mod, '__all__'))
|
||||
selected_providers = set(config_providers) & supported_providers
|
||||
|
||||
count = 0
|
||||
for provider in selected_providers:
|
||||
try:
|
||||
getattr(mod, provider)()
|
||||
count += 1
|
||||
except Exception as e:
|
||||
log.warning(f"Error occurred while checking {provider_type} provider '{provider}': {e}")
|
||||
if count == 0:
|
||||
raise ValueError(f"No valid {provider_type} providers found or defined in configs. Available: {supported_providers}")
|
||||
|
||||
class Strategy(BaseModel):
|
||||
name: str = "Conservative"
|
||||
@@ -239,7 +211,6 @@ class AppConfig(BaseModel):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.set_logging_level()
|
||||
self.api.validate_providers()
|
||||
self.models.validate_models()
|
||||
self.agents.validate_defaults(self)
|
||||
self._initialized = True
|
||||
|
||||
@@ -22,7 +22,6 @@ def pytest_configure(config:pytest.Config):
|
||||
("social", "marks tests that use social media"),
|
||||
("wrapper", "marks tests for wrapper handler"),
|
||||
|
||||
("configs", "marks tests for configuration classes"),
|
||||
("tools", "marks tests for tools"),
|
||||
("aggregator", "marks tests for market data aggregator"),
|
||||
|
||||
|
||||
@@ -1,345 +0,0 @@
|
||||
from typing import Any, Generator
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
import yaml
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app.configs import AppConfig, ModelsConfig, APIConfig, AgentsConfigs, Strategy, AppModel, Model
|
||||
|
||||
|
||||
@pytest.mark.configs
|
||||
class TestAppConfig:
|
||||
|
||||
@pytest.fixture
|
||||
def valid_config_data(self) -> dict[str, Any]:
|
||||
return {
|
||||
'port': 8080,
|
||||
'gradio_share': True,
|
||||
'logging_level': 'DEBUG',
|
||||
'strategies': [
|
||||
{'name': 'TestStrategy', 'label': 'Test', 'description': 'Test strategy'}
|
||||
],
|
||||
'models': {
|
||||
'gemini': [{'name': 'gemini-test', 'label': 'Gemini Test'}],
|
||||
'ollama': [{'name': 'test-model', 'label': 'Test Model'}]
|
||||
},
|
||||
'api': {
|
||||
'retry_attempts': 5,
|
||||
'market_providers': ['YFinanceWrapper'],
|
||||
'news_providers': ['DuckDuckGoWrapper'],
|
||||
'social_providers': ['RedditWrapper']
|
||||
},
|
||||
'agents': {
|
||||
'strategy': 'TestStrategy',
|
||||
'team_model': 'gemini-test'
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def temp_config_file(self, valid_config_data: dict[str, Any]) -> Generator[str, None, None]:
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
yaml.dump(valid_config_data, f)
|
||||
yield f.name
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_load_valid_config(self, temp_config_file: str):
|
||||
"""Test caricamento di un file di configurazione valido"""
|
||||
with patch.object(APIConfig, 'validate_providers'), \
|
||||
patch.object(ModelsConfig, 'validate_models'), \
|
||||
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||
|
||||
config = AppConfig.load(temp_config_file)
|
||||
assert config.port == 8080
|
||||
assert config.gradio_share is True
|
||||
assert config.logging_level == 'DEBUG'
|
||||
assert len(config.strategies) == 1
|
||||
assert config.strategies[0].name == 'TestStrategy'
|
||||
|
||||
def test_load_nonexistent_file(self):
|
||||
"""Test caricamento di un file inesistente"""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
AppConfig.load("nonexistent_file.yaml")
|
||||
|
||||
def test_load_invalid_yaml(self):
|
||||
"""Test caricamento di un file YAML malformato"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
f.write("invalid: yaml: content: [")
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
AppConfig.load(temp_file)
|
||||
finally:
|
||||
os.unlink(temp_file)
|
||||
|
||||
def test_singleton_pattern(self, temp_config_file: str):
|
||||
"""Test che AppConfig sia un singleton"""
|
||||
with patch.object(APIConfig, 'validate_providers'), \
|
||||
patch.object(ModelsConfig, 'validate_models'), \
|
||||
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||
|
||||
# Reset singleton for test
|
||||
if hasattr(AppConfig, 'instance'):
|
||||
delattr(AppConfig, 'instance')
|
||||
|
||||
config1 = AppConfig.load(temp_config_file)
|
||||
config2 = AppConfig.load(temp_config_file)
|
||||
assert config1 is config2
|
||||
|
||||
def test_get_model_by_name_success(self, valid_config_data: dict[str, Any]):
|
||||
"""Test recupero modello esistente"""
|
||||
with patch.object(APIConfig, 'validate_providers'), \
|
||||
patch.object(ModelsConfig, 'validate_models'), \
|
||||
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||
|
||||
config = AppConfig(**valid_config_data)
|
||||
model = config.get_model_by_name('gemini-test')
|
||||
assert model.name == 'gemini-test'
|
||||
assert model.label == 'Gemini Test'
|
||||
|
||||
def test_get_model_by_name_not_found(self, valid_config_data: dict[str, Any]):
|
||||
"""Test recupero modello inesistente"""
|
||||
with patch.object(APIConfig, 'validate_providers'), \
|
||||
patch.object(ModelsConfig, 'validate_models'), \
|
||||
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||
|
||||
config = AppConfig(**valid_config_data)
|
||||
with pytest.raises(ValueError, match="Model with name 'nonexistent' not found"):
|
||||
config.get_model_by_name('nonexistent')
|
||||
|
||||
def test_get_strategy_by_name_success(self, valid_config_data: dict[str, Any]):
|
||||
"""Test recupero strategia esistente"""
|
||||
with patch.object(APIConfig, 'validate_providers'), \
|
||||
patch.object(ModelsConfig, 'validate_models'), \
|
||||
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||
|
||||
config = AppConfig(**valid_config_data)
|
||||
strategy = config.get_strategy_by_name('TestStrategy')
|
||||
assert strategy.name == 'TestStrategy'
|
||||
assert strategy.label == 'Test'
|
||||
|
||||
def test_get_strategy_by_name_not_found(self, valid_config_data: dict[str, Any]):
|
||||
"""Test recupero strategia inesistente"""
|
||||
with patch.object(APIConfig, 'validate_providers'), \
|
||||
patch.object(ModelsConfig, 'validate_models'), \
|
||||
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||
|
||||
config = AppConfig(**valid_config_data)
|
||||
with pytest.raises(ValueError, match="Strategy with name 'nonexistent' not found"):
|
||||
config.get_strategy_by_name('nonexistent')
|
||||
|
||||
|
||||
@pytest.mark.configs
|
||||
class TestModelsConfig:
|
||||
|
||||
def test_all_models_property(self):
|
||||
"""Test proprietà all_models che combina tutti i modelli"""
|
||||
config = ModelsConfig(
|
||||
gemini=[AppModel(name='gemini-1', label='G1')],
|
||||
ollama=[AppModel(name='ollama-1', label='O1')],
|
||||
gpt=[AppModel(name='gpt-1', label='GPT1')]
|
||||
)
|
||||
|
||||
all_models = config.all_models
|
||||
assert len(all_models) == 3
|
||||
names = [m.name for m in all_models]
|
||||
assert 'gemini-1' in names
|
||||
assert 'ollama-1' in names
|
||||
assert 'gpt-1' in names
|
||||
|
||||
@patch('app.configs.os.getenv')
|
||||
def test_validate_online_models_with_api_key(self, mock_getenv: MagicMock):
|
||||
"""Test validazione modelli online con API key presente"""
|
||||
mock_getenv.return_value = "test_api_key"
|
||||
|
||||
config = ModelsConfig(gemini=[AppModel(name='gemini-test')])
|
||||
config.validate_models()
|
||||
|
||||
assert config.gemini[0].model is not None
|
||||
|
||||
@patch('app.configs.os.getenv')
|
||||
def test_validate_online_models_without_api_key(self, mock_getenv: MagicMock):
|
||||
"""Test validazione modelli online senza API key"""
|
||||
mock_getenv.return_value = None
|
||||
|
||||
config = ModelsConfig(gemini=[AppModel(name='gemini-test')])
|
||||
config.validate_models()
|
||||
|
||||
assert len(config.gemini) == 0
|
||||
|
||||
@patch('app.configs.ollama.list')
|
||||
def test_validate_ollama_models_available(self, mock_ollama_list: MagicMock):
|
||||
"""Test validazione modelli Ollama disponibili"""
|
||||
mock_ollama_list.return_value = {
|
||||
'models': [{'model': 'test-model'}, {'model': 'another-model'}]
|
||||
}
|
||||
|
||||
config = ModelsConfig(ollama=[
|
||||
AppModel(name='test-model'),
|
||||
AppModel(name='unavailable-model')
|
||||
])
|
||||
config._ModelsConfig__validate_ollama_models() # type: ignore
|
||||
|
||||
assert len(config.ollama) == 1
|
||||
assert config.ollama[0].name == 'test-model'
|
||||
assert config.ollama[0].model is not None
|
||||
|
||||
@patch('app.configs.ollama.list')
|
||||
def test_validate_ollama_models_server_error(self, mock_ollama_list: MagicMock):
|
||||
"""Test validazione modelli Ollama con nessun modello disponibile"""
|
||||
mock_ollama_list.side_effect = Exception("Connection error")
|
||||
|
||||
config = ModelsConfig(ollama=[])
|
||||
config._ModelsConfig__validate_ollama_models() # type: ignore
|
||||
|
||||
assert len(config.ollama) == 0
|
||||
|
||||
|
||||
@pytest.mark.configs
|
||||
class TestAPIConfig:
|
||||
|
||||
@patch('app.configs.importlib.import_module')
|
||||
def test_validate_providers_success(self, mock_import: MagicMock):
|
||||
"""Test validazione provider con provider validi"""
|
||||
mock_module = MagicMock()
|
||||
mock_module.__all__ = ['TestWrapper']
|
||||
mock_module.TestWrapper = MagicMock()
|
||||
mock_import.return_value = mock_module
|
||||
|
||||
config = APIConfig(
|
||||
market_providers=['TestWrapper'],
|
||||
news_providers=['TestWrapper'],
|
||||
social_providers=['TestWrapper']
|
||||
)
|
||||
|
||||
config.validate_providers() # Should not raise
|
||||
|
||||
@patch('app.configs.importlib.import_module')
|
||||
def test_validate_providers_no_valid_providers(self, mock_import: MagicMock):
|
||||
"""Test validazione provider senza provider validi"""
|
||||
mock_module = MagicMock()
|
||||
mock_module.__all__ = ['ValidWrapper']
|
||||
mock_import.return_value = mock_module
|
||||
|
||||
config = APIConfig(market_providers=['InvalidWrapper'])
|
||||
|
||||
with pytest.raises(ValueError, match="No valid markets providers found"):
|
||||
config.validate_providers()
|
||||
|
||||
@patch('app.configs.importlib.import_module')
|
||||
def test_validate_providers_with_exceptions(self, mock_import: MagicMock):
|
||||
"""Test validazione provider con eccezioni durante l'inizializzazione"""
|
||||
mock_module = MagicMock()
|
||||
mock_module.__all__ = ['TestWrapper']
|
||||
mock_module.TestWrapper.side_effect = Exception("Init error")
|
||||
mock_import.return_value = mock_module
|
||||
|
||||
config = APIConfig(market_providers=['TestWrapper'])
|
||||
|
||||
with pytest.raises(ValueError, match="No valid markets providers found"):
|
||||
config.validate_providers()
|
||||
|
||||
|
||||
@pytest.mark.configs
|
||||
class TestAgentsConfigs:
|
||||
|
||||
def test_validate_defaults_success(self):
|
||||
"""Test validazione defaults con configurazioni valide"""
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_strategy_by_name.return_value = Strategy(name='TestStrategy')
|
||||
mock_config.get_model_by_name.return_value = AppModel(name='test-model')
|
||||
|
||||
agents_config = AgentsConfigs(
|
||||
strategy='TestStrategy',
|
||||
team_model='test-model',
|
||||
team_leader_model='test-model',
|
||||
query_analyzer_model='test-model',
|
||||
report_generation_model='test-model'
|
||||
)
|
||||
|
||||
agents_config.validate_defaults(mock_config) # Should not raise
|
||||
|
||||
def test_validate_defaults_invalid_strategy(self):
|
||||
"""Test validazione defaults con strategia inesistente"""
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_strategy_by_name.side_effect = ValueError("Strategy not found")
|
||||
|
||||
agents_config = AgentsConfigs(strategy='InvalidStrategy')
|
||||
|
||||
with pytest.raises(ValueError, match="Strategy not found"):
|
||||
agents_config.validate_defaults(mock_config)
|
||||
|
||||
def test_validate_defaults_invalid_model(self):
|
||||
"""Test validazione defaults con modello inesistente"""
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_strategy_by_name.return_value = Strategy(name='TestStrategy')
|
||||
mock_config.get_model_by_name.side_effect = ValueError("Model not found")
|
||||
|
||||
agents_config = AgentsConfigs(
|
||||
strategy='TestStrategy',
|
||||
team_model='invalid-model'
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Model not found"):
|
||||
agents_config.validate_defaults(mock_config)
|
||||
|
||||
|
||||
@pytest.mark.configs
|
||||
class TestAppModel:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance(self) -> tuple[MagicMock, type[Model]]:
|
||||
mock_instance = MagicMock()
|
||||
|
||||
# Use a concrete subclass of the application's Model base so pydantic validation passes,
|
||||
# and make instantiation return the mock instance.
|
||||
class DummyModel(Model):
|
||||
def __new__(cls, id: str, instructions: list[str]):
|
||||
return mock_instance
|
||||
return mock_instance, DummyModel
|
||||
|
||||
def test_get_model_success(self, mock_model_instance: tuple[MagicMock, type[Model]]):
|
||||
"""Test creazione modello con classe impostata"""
|
||||
app_model = AppModel(name='test-model', model=mock_model_instance[1])
|
||||
result = app_model.get_model("test instructions")
|
||||
assert result == mock_model_instance[0]
|
||||
|
||||
def test_get_model_no_class_set(self):
|
||||
"""Test creazione modello senza classe impostata"""
|
||||
app_model = AppModel(name='test-model')
|
||||
|
||||
with pytest.raises(ValueError, match="Model class for 'test-model' is not set"):
|
||||
app_model.get_model("test instructions")
|
||||
|
||||
def test_get_agent_success(self, mock_model_instance: tuple[MagicMock, type[Model]]):
|
||||
"""Test creazione agente con modello valido"""
|
||||
with patch('app.configs.Agent') as mock_agent_class:
|
||||
mock_agent_instance = MagicMock()
|
||||
mock_agent_class.return_value = mock_agent_instance
|
||||
|
||||
app_model = AppModel(name='test-model', model=mock_model_instance[1])
|
||||
result = app_model.get_agent(instructions="test instructions", name="agent_name")
|
||||
mock_agent_class.assert_called_once()
|
||||
assert result == mock_agent_instance
|
||||
|
||||
|
||||
@pytest.mark.configs
|
||||
class TestStrategy:
|
||||
|
||||
def test_strategy_defaults(self):
|
||||
"""Test valori di default per Strategy"""
|
||||
strategy = Strategy()
|
||||
assert strategy.name == "Conservative"
|
||||
assert strategy.label == "Conservative"
|
||||
assert "low-risk" in strategy.description.lower()
|
||||
|
||||
def test_strategy_custom_values(self):
|
||||
"""Test Strategy con valori personalizzati"""
|
||||
strategy = Strategy(
|
||||
name="Aggressive",
|
||||
label="High Risk",
|
||||
description="High-risk strategy"
|
||||
)
|
||||
assert strategy.name == "Aggressive"
|
||||
assert strategy.label == "High Risk"
|
||||
assert strategy.description == "High-risk strategy"
|
||||
@@ -9,11 +9,11 @@ class TestMarketDataAggregator:
|
||||
|
||||
def __product(self, symbol: str, price: float, volume: float, currency: str) -> ProductInfo:
|
||||
prod = ProductInfo()
|
||||
prod.id=f"{symbol}-{currency}"
|
||||
prod.symbol=symbol
|
||||
prod.price=price
|
||||
prod.volume_24h=volume
|
||||
prod.currency=currency
|
||||
prod.id = f"{symbol}-{currency}"
|
||||
prod.symbol = symbol
|
||||
prod.price = price
|
||||
prod.volume_24h = volume
|
||||
prod.currency = currency
|
||||
return prod
|
||||
|
||||
def __price(self, timestamp_s: int, high: float, low: float, open: float, close: float, volume: float) -> Price:
|
||||
@@ -38,12 +38,16 @@ class TestMarketDataAggregator:
|
||||
|
||||
info = aggregated[0]
|
||||
assert info is not None
|
||||
assert info.id == "BTCUSD_AGGREGATED"
|
||||
assert info.symbol == "BTC"
|
||||
assert info.currency == "USD"
|
||||
assert "Provider1" in info.provider
|
||||
assert "Provider2" in info.provider
|
||||
assert "Provider3" in info.provider
|
||||
|
||||
avg_weighted_price = (50000.0 * 1000.0 + 50100.0 * 1100.0 + 49900.0 * 900.0) / (1000.0 + 1100.0 + 900.0)
|
||||
assert info.price == pytest.approx(avg_weighted_price, rel=1e-3) # type: ignore
|
||||
assert info.volume_24h == pytest.approx(1000.0, rel=1e-3) # type: ignore
|
||||
assert info.currency == "USD"
|
||||
|
||||
def test_aggregate_product_info_multiple_symbols(self):
|
||||
products = {
|
||||
@@ -127,3 +131,80 @@ class TestMarketDataAggregator:
|
||||
assert aggregated[1].timestamp == timestamp_2h_ago
|
||||
assert aggregated[1].high == pytest.approx(50250.0, rel=1e-3) # type: ignore
|
||||
assert aggregated[1].low == pytest.approx(49850.0, rel=1e-3) # type: ignore
|
||||
|
||||
def test_aggregate_product_info_different_currencies(self):
|
||||
products = {
|
||||
"Provider1": [self.__product("BTC", 100000.0, 1000.0, "USD")],
|
||||
"Provider2": [self.__product("BTC", 70000.0, 800.0, "EUR")],
|
||||
}
|
||||
|
||||
aggregated = ProductInfo.aggregate(products)
|
||||
assert len(aggregated) == 1
|
||||
|
||||
info = aggregated[0]
|
||||
assert info is not None
|
||||
assert info.id == "BTCUSD_AGGREGATED"
|
||||
assert info.symbol == "BTC"
|
||||
assert info.currency == "USD" # Only USD products are kept
|
||||
# When currencies differ, only USD is aggregated (only Provider1 in this case)
|
||||
assert info.price == pytest.approx(100000.0, rel=1e-3) # type: ignore
|
||||
assert info.volume_24h == pytest.approx(1000.0, rel=1e-3) # type: ignore # Only USD volume
|
||||
|
||||
def test_aggregate_product_info_empty_providers(self):
|
||||
"""Test aggregate_product_info with some providers returning empty lists"""
|
||||
products: dict[str, list[ProductInfo]] = {
|
||||
"Provider1": [self.__product("BTC", 50000.0, 1000.0, "USD")],
|
||||
"Provider2": [],
|
||||
"Provider3": [self.__product("BTC", 50100.0, 1100.0, "USD")],
|
||||
}
|
||||
|
||||
aggregated = ProductInfo.aggregate(products)
|
||||
assert len(aggregated) == 1
|
||||
info = aggregated[0]
|
||||
assert info.symbol == "BTC"
|
||||
assert "Provider1" in info.provider
|
||||
assert "Provider2" not in info.provider
|
||||
assert "Provider3" in info.provider
|
||||
|
||||
def test_aggregate_product_info_mixed_symbols(self):
|
||||
"""Test that aggregate_product_info correctly separates different symbols"""
|
||||
products = {
|
||||
"Provider1": [
|
||||
self.__product("BTC", 50000.0, 1000.0, "USD"),
|
||||
self.__product("ETH", 4000.0, 2000.0, "USD"),
|
||||
self.__product("SOL", 100.0, 500.0, "USD"),
|
||||
],
|
||||
"Provider2": [
|
||||
self.__product("BTC", 50100.0, 1100.0, "USD"),
|
||||
self.__product("ETH", 4050.0, 2100.0, "USD"),
|
||||
],
|
||||
}
|
||||
|
||||
aggregated = ProductInfo.aggregate(products)
|
||||
assert len(aggregated) == 3
|
||||
|
||||
symbols = {p.symbol for p in aggregated}
|
||||
assert symbols == {"BTC", "ETH", "SOL"}
|
||||
|
||||
btc = next(p for p in aggregated if p.symbol == "BTC")
|
||||
assert "Provider1" in btc.provider and "Provider2" in btc.provider
|
||||
|
||||
sol = next(p for p in aggregated if p.symbol == "SOL")
|
||||
assert sol.provider == "Provider1" # Only one provider
|
||||
|
||||
def test_aggregate_product_info_zero_volume(self):
|
||||
"""Test aggregazione quando tutti i prodotti hanno volume zero"""
|
||||
products = {
|
||||
"Provider1": [self.__product("BTC", 50000.0, 0.0, "USD")],
|
||||
"Provider2": [self.__product("BTC", 50100.0, 0.0, "USD")],
|
||||
"Provider3": [self.__product("BTC", 49900.0, 0.0, "USD")],
|
||||
}
|
||||
|
||||
aggregated = ProductInfo.aggregate(products)
|
||||
assert len(aggregated) == 1
|
||||
|
||||
info = aggregated[0]
|
||||
# Con volume zero, dovrebbe usare la media semplice dei prezzi
|
||||
expected_price = (50000.0 + 50100.0 + 49900.0) / 3
|
||||
assert info.price == pytest.approx(expected_price, rel=1e-3) # type: ignore
|
||||
assert info.volume_24h == 0.0
|
||||
|
||||
Reference in New Issue
Block a user