Compare commits
1 Commits
fix-query-
...
fix-config
| Author | SHA1 | Date | |
|---|---|---|---|
| 24d73b6bf8 |
@@ -107,12 +107,7 @@ class Pipeline:
|
|||||||
def condition_query_ok(step_input: StepInput) -> StepOutput:
|
def condition_query_ok(step_input: StepInput) -> StepOutput:
|
||||||
val = step_input.previous_step_content
|
val = step_input.previous_step_content
|
||||||
stop = (not val.is_crypto) if isinstance(val, QueryOutputs) else True
|
stop = (not val.is_crypto) if isinstance(val, QueryOutputs) else True
|
||||||
return StepOutput(stop=stop, content=step_input.input)
|
return StepOutput(stop=stop)
|
||||||
|
|
||||||
def sanitization_output(step_input: StepInput) -> StepOutput:
|
|
||||||
val = step_input.previous_step_content
|
|
||||||
content = f"Query: {step_input.input}\n\nRetrieved data: {self.remove_think(str(val))}"
|
|
||||||
return StepOutput(content=content)
|
|
||||||
|
|
||||||
query_check = Step(name=PipelineEvent.QUERY_CHECK, agent=query_check)
|
query_check = Step(name=PipelineEvent.QUERY_CHECK, agent=query_check)
|
||||||
info_recovery = Step(name=PipelineEvent.INFO_RECOVERY, team=team)
|
info_recovery = Step(name=PipelineEvent.INFO_RECOVERY, team=team)
|
||||||
@@ -123,7 +118,6 @@ class Pipeline:
|
|||||||
query_check,
|
query_check,
|
||||||
condition_query_ok,
|
condition_query_ok,
|
||||||
info_recovery,
|
info_recovery,
|
||||||
sanitization_output,
|
|
||||||
report_generation
|
report_generation
|
||||||
])
|
])
|
||||||
|
|
||||||
@@ -156,22 +150,11 @@ class Pipeline:
|
|||||||
|
|
||||||
# Restituisce la risposta finale
|
# Restituisce la risposta finale
|
||||||
if content and isinstance(content, str):
|
if content and isinstance(content, str):
|
||||||
yield cls.remove_think(content)
|
think_str = "</think>"
|
||||||
|
think = content.rfind(think_str)
|
||||||
|
yield content[(think + len(think_str)):] if think != -1 else content
|
||||||
elif content and isinstance(content, QueryOutputs):
|
elif content and isinstance(content, QueryOutputs):
|
||||||
yield content.response
|
yield content.response
|
||||||
else:
|
else:
|
||||||
logging.error(f"No output from workflow: {content}")
|
logging.error(f"No output from workflow: {content}")
|
||||||
yield "Nessun output dal workflow, qualcosa è andato storto."
|
yield "Nessun output dal workflow, qualcosa è andato storto."
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def remove_think(cls, text: str) -> str:
|
|
||||||
"""
|
|
||||||
Rimuove la sezione di pensiero dal testo.
|
|
||||||
Args:
|
|
||||||
text: Il testo da pulire.
|
|
||||||
Returns:
|
|
||||||
Il testo senza la sezione di pensiero.
|
|
||||||
"""
|
|
||||||
think_str = "</think>"
|
|
||||||
think = text.rfind(think_str)
|
|
||||||
return text[(think + len(think_str)):] if think != -1 else text
|
|
||||||
|
|||||||
@@ -13,8 +13,3 @@
|
|||||||
- IS_CRYPTO: (empty)
|
- IS_CRYPTO: (empty)
|
||||||
- NOT_CRYPTO: "I can only analyze cryptocurrency topics."
|
- NOT_CRYPTO: "I can only analyze cryptocurrency topics."
|
||||||
- AMBIGUOUS: "Which cryptocurrency? (e.g., Bitcoin, Ethereum)"
|
- AMBIGUOUS: "Which cryptocurrency? (e.g., Bitcoin, Ethereum)"
|
||||||
|
|
||||||
**RULES:**
|
|
||||||
- DO NOT ANSWER the query.
|
|
||||||
- DO NOT PROVIDE ADDITIONAL INFORMATION.
|
|
||||||
- STOP instantly WHEN YOU CLASSIFY the query.
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
- NEVER use placeholders ("N/A", "Data not available") - OMIT section instead
|
- NEVER use placeholders ("N/A", "Data not available") - OMIT section instead
|
||||||
- NO example/placeholder data
|
- NO example/placeholder data
|
||||||
|
|
||||||
**INPUT:** You will get the original user query and a structured report with optional sections:
|
**INPUT:** Structured report from Team Leader with optional sections:
|
||||||
- Overall Summary
|
- Overall Summary
|
||||||
- Market & Price Data (opt)
|
- Market & Price Data (opt)
|
||||||
- News & Market Sentiment (opt)
|
- News & Market Sentiment (opt)
|
||||||
|
|||||||
@@ -79,6 +79,4 @@ Timestamp: {{CURRENT_DATE}}
|
|||||||
- Never modify MarketAgent prices
|
- Never modify MarketAgent prices
|
||||||
- Include all timestamps/sources
|
- Include all timestamps/sources
|
||||||
- Retry failed tasks (max 3)
|
- Retry failed tasks (max 3)
|
||||||
- Only report agent data
|
- Only report agent data
|
||||||
- DO NOT fabricate or add info
|
|
||||||
- DO NOT add sources if none provided
|
|
||||||
@@ -19,7 +19,7 @@ Historical: `{Asset, Period: {Start, End}, Data Points, Price Range: {Low, High}
|
|||||||
**MANDATORY RULES:**
|
**MANDATORY RULES:**
|
||||||
1. **Include timestamps** for every price data point
|
1. **Include timestamps** for every price data point
|
||||||
2. **Never fabricate** prices or dates - only report tool outputs
|
2. **Never fabricate** prices or dates - only report tool outputs
|
||||||
3. **Specify the data source** if provided, else state "source unavailable"
|
3. **Always specify the data source** (which API provided the data)
|
||||||
4. **Report data completeness**: If user asks for 30 days but got 7, state this explicitly
|
4. **Report data completeness**: If user asks for 30 days but got 7, state this explicitly
|
||||||
5. **Current date context**: Remind that data is as of {{CURRENT_DATE}}
|
5. **Current date context**: Remind that data is as of {{CURRENT_DATE}}
|
||||||
6. **Token Optimization**: Be extremely concise to save tokens. Provide all necessary data using as few words as possible. Exceed 100 words ONLY if absolutely necessary to include all required data points.
|
6. **Token Optimization**: Be extremely concise to save tokens. Provide all necessary data using as few words as possible. Exceed 100 words ONLY if absolutely necessary to include all required data points.
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
import ollama
|
import ollama
|
||||||
import yaml
|
import yaml
|
||||||
|
import importlib
|
||||||
import logging.config
|
import logging.config
|
||||||
from typing import Any, ClassVar
|
from typing import Any, ClassVar
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -67,7 +68,34 @@ class APIConfig(BaseModel):
|
|||||||
news_providers: list[str] = []
|
news_providers: list[str] = []
|
||||||
social_providers: list[str] = []
|
social_providers: list[str] = []
|
||||||
|
|
||||||
|
def validate_providers(self) -> None:
|
||||||
|
"""
|
||||||
|
Validate that the configured providers are supported.
|
||||||
|
Raises:
|
||||||
|
ValueError if any provider is not supported.
|
||||||
|
"""
|
||||||
|
modules = [
|
||||||
|
('app.api.markets', self.market_providers),
|
||||||
|
('app.api.news', self.news_providers),
|
||||||
|
('app.api.social', self.social_providers),
|
||||||
|
]
|
||||||
|
|
||||||
|
for (module, config_providers) in modules:
|
||||||
|
provider_type = module.split('.')[-1]
|
||||||
|
mod = importlib.import_module(module)
|
||||||
|
|
||||||
|
supported_providers = set(getattr(mod, '__all__'))
|
||||||
|
selected_providers = set(config_providers) & supported_providers
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for provider in selected_providers:
|
||||||
|
try:
|
||||||
|
getattr(mod, provider)()
|
||||||
|
count += 1
|
||||||
|
except Exception as e:
|
||||||
|
log.warning(f"Error occurred while checking {provider_type} provider '{provider}': {e}")
|
||||||
|
if count == 0:
|
||||||
|
raise ValueError(f"No valid {provider_type} providers found or defined in configs. Available: {supported_providers}")
|
||||||
|
|
||||||
class Strategy(BaseModel):
|
class Strategy(BaseModel):
|
||||||
name: str = "Conservative"
|
name: str = "Conservative"
|
||||||
@@ -211,6 +239,7 @@ class AppConfig(BaseModel):
|
|||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.set_logging_level()
|
self.set_logging_level()
|
||||||
|
self.api.validate_providers()
|
||||||
self.models.validate_models()
|
self.models.validate_models()
|
||||||
self.agents.validate_defaults(self)
|
self.agents.validate_defaults(self)
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ def pytest_configure(config:pytest.Config):
|
|||||||
("social", "marks tests that use social media"),
|
("social", "marks tests that use social media"),
|
||||||
("wrapper", "marks tests for wrapper handler"),
|
("wrapper", "marks tests for wrapper handler"),
|
||||||
|
|
||||||
|
("configs", "marks tests for configuration classes"),
|
||||||
("tools", "marks tests for tools"),
|
("tools", "marks tests for tools"),
|
||||||
("aggregator", "marks tests for market data aggregator"),
|
("aggregator", "marks tests for market data aggregator"),
|
||||||
|
|
||||||
|
|||||||
345
tests/utils/test_configs.py
Normal file
345
tests/utils/test_configs.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
from typing import Any, Generator
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import yaml
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from app.configs import AppConfig, ModelsConfig, APIConfig, AgentsConfigs, Strategy, AppModel, Model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.configs
|
||||||
|
class TestAppConfig:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_config_data(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
'port': 8080,
|
||||||
|
'gradio_share': True,
|
||||||
|
'logging_level': 'DEBUG',
|
||||||
|
'strategies': [
|
||||||
|
{'name': 'TestStrategy', 'label': 'Test', 'description': 'Test strategy'}
|
||||||
|
],
|
||||||
|
'models': {
|
||||||
|
'gemini': [{'name': 'gemini-test', 'label': 'Gemini Test'}],
|
||||||
|
'ollama': [{'name': 'test-model', 'label': 'Test Model'}]
|
||||||
|
},
|
||||||
|
'api': {
|
||||||
|
'retry_attempts': 5,
|
||||||
|
'market_providers': ['YFinanceWrapper'],
|
||||||
|
'news_providers': ['DuckDuckGoWrapper'],
|
||||||
|
'social_providers': ['RedditWrapper']
|
||||||
|
},
|
||||||
|
'agents': {
|
||||||
|
'strategy': 'TestStrategy',
|
||||||
|
'team_model': 'gemini-test'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_config_file(self, valid_config_data: dict[str, Any]) -> Generator[str, None, None]:
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||||
|
yaml.dump(valid_config_data, f)
|
||||||
|
yield f.name
|
||||||
|
os.unlink(f.name)
|
||||||
|
|
||||||
|
def test_load_valid_config(self, temp_config_file: str):
|
||||||
|
"""Test caricamento di un file di configurazione valido"""
|
||||||
|
with patch.object(APIConfig, 'validate_providers'), \
|
||||||
|
patch.object(ModelsConfig, 'validate_models'), \
|
||||||
|
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||||
|
|
||||||
|
config = AppConfig.load(temp_config_file)
|
||||||
|
assert config.port == 8080
|
||||||
|
assert config.gradio_share is True
|
||||||
|
assert config.logging_level == 'DEBUG'
|
||||||
|
assert len(config.strategies) == 1
|
||||||
|
assert config.strategies[0].name == 'TestStrategy'
|
||||||
|
|
||||||
|
def test_load_nonexistent_file(self):
|
||||||
|
"""Test caricamento di un file inesistente"""
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
AppConfig.load("nonexistent_file.yaml")
|
||||||
|
|
||||||
|
def test_load_invalid_yaml(self):
|
||||||
|
"""Test caricamento di un file YAML malformato"""
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||||
|
f.write("invalid: yaml: content: [")
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(yaml.YAMLError):
|
||||||
|
AppConfig.load(temp_file)
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_singleton_pattern(self, temp_config_file: str):
|
||||||
|
"""Test che AppConfig sia un singleton"""
|
||||||
|
with patch.object(APIConfig, 'validate_providers'), \
|
||||||
|
patch.object(ModelsConfig, 'validate_models'), \
|
||||||
|
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||||
|
|
||||||
|
# Reset singleton for test
|
||||||
|
if hasattr(AppConfig, 'instance'):
|
||||||
|
delattr(AppConfig, 'instance')
|
||||||
|
|
||||||
|
config1 = AppConfig.load(temp_config_file)
|
||||||
|
config2 = AppConfig.load(temp_config_file)
|
||||||
|
assert config1 is config2
|
||||||
|
|
||||||
|
def test_get_model_by_name_success(self, valid_config_data: dict[str, Any]):
|
||||||
|
"""Test recupero modello esistente"""
|
||||||
|
with patch.object(APIConfig, 'validate_providers'), \
|
||||||
|
patch.object(ModelsConfig, 'validate_models'), \
|
||||||
|
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||||
|
|
||||||
|
config = AppConfig(**valid_config_data)
|
||||||
|
model = config.get_model_by_name('gemini-test')
|
||||||
|
assert model.name == 'gemini-test'
|
||||||
|
assert model.label == 'Gemini Test'
|
||||||
|
|
||||||
|
def test_get_model_by_name_not_found(self, valid_config_data: dict[str, Any]):
|
||||||
|
"""Test recupero modello inesistente"""
|
||||||
|
with patch.object(APIConfig, 'validate_providers'), \
|
||||||
|
patch.object(ModelsConfig, 'validate_models'), \
|
||||||
|
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||||
|
|
||||||
|
config = AppConfig(**valid_config_data)
|
||||||
|
with pytest.raises(ValueError, match="Model with name 'nonexistent' not found"):
|
||||||
|
config.get_model_by_name('nonexistent')
|
||||||
|
|
||||||
|
def test_get_strategy_by_name_success(self, valid_config_data: dict[str, Any]):
|
||||||
|
"""Test recupero strategia esistente"""
|
||||||
|
with patch.object(APIConfig, 'validate_providers'), \
|
||||||
|
patch.object(ModelsConfig, 'validate_models'), \
|
||||||
|
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||||
|
|
||||||
|
config = AppConfig(**valid_config_data)
|
||||||
|
strategy = config.get_strategy_by_name('TestStrategy')
|
||||||
|
assert strategy.name == 'TestStrategy'
|
||||||
|
assert strategy.label == 'Test'
|
||||||
|
|
||||||
|
def test_get_strategy_by_name_not_found(self, valid_config_data: dict[str, Any]):
|
||||||
|
"""Test recupero strategia inesistente"""
|
||||||
|
with patch.object(APIConfig, 'validate_providers'), \
|
||||||
|
patch.object(ModelsConfig, 'validate_models'), \
|
||||||
|
patch.object(AgentsConfigs, 'validate_defaults'):
|
||||||
|
|
||||||
|
config = AppConfig(**valid_config_data)
|
||||||
|
with pytest.raises(ValueError, match="Strategy with name 'nonexistent' not found"):
|
||||||
|
config.get_strategy_by_name('nonexistent')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.configs
|
||||||
|
class TestModelsConfig:
|
||||||
|
|
||||||
|
def test_all_models_property(self):
|
||||||
|
"""Test proprietà all_models che combina tutti i modelli"""
|
||||||
|
config = ModelsConfig(
|
||||||
|
gemini=[AppModel(name='gemini-1', label='G1')],
|
||||||
|
ollama=[AppModel(name='ollama-1', label='O1')],
|
||||||
|
gpt=[AppModel(name='gpt-1', label='GPT1')]
|
||||||
|
)
|
||||||
|
|
||||||
|
all_models = config.all_models
|
||||||
|
assert len(all_models) == 3
|
||||||
|
names = [m.name for m in all_models]
|
||||||
|
assert 'gemini-1' in names
|
||||||
|
assert 'ollama-1' in names
|
||||||
|
assert 'gpt-1' in names
|
||||||
|
|
||||||
|
@patch('app.configs.os.getenv')
|
||||||
|
def test_validate_online_models_with_api_key(self, mock_getenv: MagicMock):
|
||||||
|
"""Test validazione modelli online con API key presente"""
|
||||||
|
mock_getenv.return_value = "test_api_key"
|
||||||
|
|
||||||
|
config = ModelsConfig(gemini=[AppModel(name='gemini-test')])
|
||||||
|
config.validate_models()
|
||||||
|
|
||||||
|
assert config.gemini[0].model is not None
|
||||||
|
|
||||||
|
@patch('app.configs.os.getenv')
|
||||||
|
def test_validate_online_models_without_api_key(self, mock_getenv: MagicMock):
|
||||||
|
"""Test validazione modelli online senza API key"""
|
||||||
|
mock_getenv.return_value = None
|
||||||
|
|
||||||
|
config = ModelsConfig(gemini=[AppModel(name='gemini-test')])
|
||||||
|
config.validate_models()
|
||||||
|
|
||||||
|
assert len(config.gemini) == 0
|
||||||
|
|
||||||
|
@patch('app.configs.ollama.list')
|
||||||
|
def test_validate_ollama_models_available(self, mock_ollama_list: MagicMock):
|
||||||
|
"""Test validazione modelli Ollama disponibili"""
|
||||||
|
mock_ollama_list.return_value = {
|
||||||
|
'models': [{'model': 'test-model'}, {'model': 'another-model'}]
|
||||||
|
}
|
||||||
|
|
||||||
|
config = ModelsConfig(ollama=[
|
||||||
|
AppModel(name='test-model'),
|
||||||
|
AppModel(name='unavailable-model')
|
||||||
|
])
|
||||||
|
config._ModelsConfig__validate_ollama_models() # type: ignore
|
||||||
|
|
||||||
|
assert len(config.ollama) == 1
|
||||||
|
assert config.ollama[0].name == 'test-model'
|
||||||
|
assert config.ollama[0].model is not None
|
||||||
|
|
||||||
|
@patch('app.configs.ollama.list')
|
||||||
|
def test_validate_ollama_models_server_error(self, mock_ollama_list: MagicMock):
|
||||||
|
"""Test validazione modelli Ollama con nessun modello disponibile"""
|
||||||
|
mock_ollama_list.side_effect = Exception("Connection error")
|
||||||
|
|
||||||
|
config = ModelsConfig(ollama=[])
|
||||||
|
config._ModelsConfig__validate_ollama_models() # type: ignore
|
||||||
|
|
||||||
|
assert len(config.ollama) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.configs
|
||||||
|
class TestAPIConfig:
|
||||||
|
|
||||||
|
@patch('app.configs.importlib.import_module')
|
||||||
|
def test_validate_providers_success(self, mock_import: MagicMock):
|
||||||
|
"""Test validazione provider con provider validi"""
|
||||||
|
mock_module = MagicMock()
|
||||||
|
mock_module.__all__ = ['TestWrapper']
|
||||||
|
mock_module.TestWrapper = MagicMock()
|
||||||
|
mock_import.return_value = mock_module
|
||||||
|
|
||||||
|
config = APIConfig(
|
||||||
|
market_providers=['TestWrapper'],
|
||||||
|
news_providers=['TestWrapper'],
|
||||||
|
social_providers=['TestWrapper']
|
||||||
|
)
|
||||||
|
|
||||||
|
config.validate_providers() # Should not raise
|
||||||
|
|
||||||
|
@patch('app.configs.importlib.import_module')
|
||||||
|
def test_validate_providers_no_valid_providers(self, mock_import: MagicMock):
|
||||||
|
"""Test validazione provider senza provider validi"""
|
||||||
|
mock_module = MagicMock()
|
||||||
|
mock_module.__all__ = ['ValidWrapper']
|
||||||
|
mock_import.return_value = mock_module
|
||||||
|
|
||||||
|
config = APIConfig(market_providers=['InvalidWrapper'])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No valid markets providers found"):
|
||||||
|
config.validate_providers()
|
||||||
|
|
||||||
|
@patch('app.configs.importlib.import_module')
|
||||||
|
def test_validate_providers_with_exceptions(self, mock_import: MagicMock):
|
||||||
|
"""Test validazione provider con eccezioni durante l'inizializzazione"""
|
||||||
|
mock_module = MagicMock()
|
||||||
|
mock_module.__all__ = ['TestWrapper']
|
||||||
|
mock_module.TestWrapper.side_effect = Exception("Init error")
|
||||||
|
mock_import.return_value = mock_module
|
||||||
|
|
||||||
|
config = APIConfig(market_providers=['TestWrapper'])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No valid markets providers found"):
|
||||||
|
config.validate_providers()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.configs
|
||||||
|
class TestAgentsConfigs:
|
||||||
|
|
||||||
|
def test_validate_defaults_success(self):
|
||||||
|
"""Test validazione defaults con configurazioni valide"""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get_strategy_by_name.return_value = Strategy(name='TestStrategy')
|
||||||
|
mock_config.get_model_by_name.return_value = AppModel(name='test-model')
|
||||||
|
|
||||||
|
agents_config = AgentsConfigs(
|
||||||
|
strategy='TestStrategy',
|
||||||
|
team_model='test-model',
|
||||||
|
team_leader_model='test-model',
|
||||||
|
query_analyzer_model='test-model',
|
||||||
|
report_generation_model='test-model'
|
||||||
|
)
|
||||||
|
|
||||||
|
agents_config.validate_defaults(mock_config) # Should not raise
|
||||||
|
|
||||||
|
def test_validate_defaults_invalid_strategy(self):
|
||||||
|
"""Test validazione defaults con strategia inesistente"""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get_strategy_by_name.side_effect = ValueError("Strategy not found")
|
||||||
|
|
||||||
|
agents_config = AgentsConfigs(strategy='InvalidStrategy')
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Strategy not found"):
|
||||||
|
agents_config.validate_defaults(mock_config)
|
||||||
|
|
||||||
|
def test_validate_defaults_invalid_model(self):
|
||||||
|
"""Test validazione defaults con modello inesistente"""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get_strategy_by_name.return_value = Strategy(name='TestStrategy')
|
||||||
|
mock_config.get_model_by_name.side_effect = ValueError("Model not found")
|
||||||
|
|
||||||
|
agents_config = AgentsConfigs(
|
||||||
|
strategy='TestStrategy',
|
||||||
|
team_model='invalid-model'
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Model not found"):
|
||||||
|
agents_config.validate_defaults(mock_config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.configs
|
||||||
|
class TestAppModel:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_model_instance(self) -> tuple[MagicMock, type[Model]]:
|
||||||
|
mock_instance = MagicMock()
|
||||||
|
|
||||||
|
# Use a concrete subclass of the application's Model base so pydantic validation passes,
|
||||||
|
# and make instantiation return the mock instance.
|
||||||
|
class DummyModel(Model):
|
||||||
|
def __new__(cls, id: str, instructions: list[str]):
|
||||||
|
return mock_instance
|
||||||
|
return mock_instance, DummyModel
|
||||||
|
|
||||||
|
def test_get_model_success(self, mock_model_instance: tuple[MagicMock, type[Model]]):
|
||||||
|
"""Test creazione modello con classe impostata"""
|
||||||
|
app_model = AppModel(name='test-model', model=mock_model_instance[1])
|
||||||
|
result = app_model.get_model("test instructions")
|
||||||
|
assert result == mock_model_instance[0]
|
||||||
|
|
||||||
|
def test_get_model_no_class_set(self):
|
||||||
|
"""Test creazione modello senza classe impostata"""
|
||||||
|
app_model = AppModel(name='test-model')
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Model class for 'test-model' is not set"):
|
||||||
|
app_model.get_model("test instructions")
|
||||||
|
|
||||||
|
def test_get_agent_success(self, mock_model_instance: tuple[MagicMock, type[Model]]):
|
||||||
|
"""Test creazione agente con modello valido"""
|
||||||
|
with patch('app.configs.Agent') as mock_agent_class:
|
||||||
|
mock_agent_instance = MagicMock()
|
||||||
|
mock_agent_class.return_value = mock_agent_instance
|
||||||
|
|
||||||
|
app_model = AppModel(name='test-model', model=mock_model_instance[1])
|
||||||
|
result = app_model.get_agent(instructions="test instructions", name="agent_name")
|
||||||
|
mock_agent_class.assert_called_once()
|
||||||
|
assert result == mock_agent_instance
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.configs
|
||||||
|
class TestStrategy:
|
||||||
|
|
||||||
|
def test_strategy_defaults(self):
|
||||||
|
"""Test valori di default per Strategy"""
|
||||||
|
strategy = Strategy()
|
||||||
|
assert strategy.name == "Conservative"
|
||||||
|
assert strategy.label == "Conservative"
|
||||||
|
assert "low-risk" in strategy.description.lower()
|
||||||
|
|
||||||
|
def test_strategy_custom_values(self):
|
||||||
|
"""Test Strategy con valori personalizzati"""
|
||||||
|
strategy = Strategy(
|
||||||
|
name="Aggressive",
|
||||||
|
label="High Risk",
|
||||||
|
description="High-risk strategy"
|
||||||
|
)
|
||||||
|
assert strategy.name == "Aggressive"
|
||||||
|
assert strategy.label == "High Risk"
|
||||||
|
assert strategy.description == "High-risk strategy"
|
||||||
Reference in New Issue
Block a user