Refactor configs dei modelli + fix chat interface default selection

This commit is contained in:
2025-10-21 16:57:28 +02:00
parent 2e092d3f25
commit 22715237e1
2 changed files with 58 additions and 53 deletions

View File

@@ -54,15 +54,21 @@ class AppModel(BaseModel):
output_schema=output_schema output_schema=output_schema
) )
class APIConfig(BaseModel): class APIConfig(BaseModel):
retry_attempts: int = 3 retry_attempts: int = 3
retry_delay_seconds: int = 2 retry_delay_seconds: int = 2
class Strategy(BaseModel): class Strategy(BaseModel):
name: str = "Conservative" name: str = "Conservative"
label: str = "Conservative" label: str = "Conservative"
description: str = "Focus on low-risk investments with steady returns." description: str = "Focus on low-risk investments with steady returns."
class ModelsConfig(BaseModel): class ModelsConfig(BaseModel):
gemini: list[AppModel] = [AppModel()] gemini: list[AppModel] = [AppModel()]
ollama: list[AppModel] = [] ollama: list[AppModel] = []
@@ -71,6 +77,53 @@ class ModelsConfig(BaseModel):
def all_models(self) -> list[AppModel]: def all_models(self) -> list[AppModel]:
return self.gemini + self.ollama return self.gemini + self.ollama
def validate_models(self) -> None:
"""
Validate the configured models for each provider.
"""
self.__validate_online_models(self.gemini, clazz=Gemini, key="GOOGLE_API_KEY")
self.__validate_ollama_models()
def __validate_online_models(self, models: list[AppModel], clazz: type[Model], key: str | None = None) -> None:
"""
Validate models for online providers like Gemini.
Args:
models: list of AppModel instances to validate
clazz: class of the model (e.g. Gemini)
key: API key required for the provider (optional)
"""
if key and os.getenv(key) is None:
log.warning(f"No {key} set in environment variables for provider.")
models.clear()
return
for model in models:
model.model = clazz
def __validate_ollama_models(self) -> None:
"""
Validate models for the Ollama provider.
"""
try:
models_list = ollama.list()
availables = {model['model'] for model in models_list['models']}
not_availables: list[str] = []
for model in self.ollama:
if model.name in availables:
model.model = Ollama
else:
not_availables.append(model.name)
if not_availables:
log.warning(f"Ollama models not available: {not_availables}")
self.ollama = [model for model in self.ollama if model.model]
except Exception as e:
log.warning(f"Ollama is not running or not reachable: {e}")
class AgentsConfigs(BaseModel): class AgentsConfigs(BaseModel):
strategy: str = "Conservative" strategy: str = "Conservative"
team_model: str = "gemini-2.0-flash" team_model: str = "gemini-2.0-flash"
@@ -118,7 +171,7 @@ class AppConfig(BaseModel):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.set_logging_level() self.set_logging_level()
self.validate_models() self.models.validate_models()
self._initialized = True self._initialized = True
def get_model_by_name(self, name: str) -> AppModel: def get_model_by_name(self, name: str) -> AppModel:
@@ -186,53 +239,3 @@ class AppConfig(BaseModel):
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
logger.handlers.clear() logger.handlers.clear()
logger.propagate = True logger.propagate = True
def validate_models(self) -> None:
"""
Validate the configured models for each provider.
"""
self.__validate_online_models("gemini", clazz=Gemini, key="GOOGLE_API_KEY")
self.__validate_ollama_models()
def __validate_online_models(self, provider: str, clazz: type[Model], key: str | None = None) -> None:
"""
Validate models for online providers like Gemini.
Args:
provider: name of the provider (e.g. "gemini")
clazz: class of the model (e.g. Gemini)
key: API key required for the provider (optional)
"""
if getattr(self.models, provider) is None:
log.warning(f"No models configured for provider '{provider}'.")
models: list[AppModel] = getattr(self.models, provider)
if key and os.getenv(key) is None:
log.warning(f"No {key} set in environment variables for {provider}.")
models.clear()
return
for model in models:
model.model = clazz
def __validate_ollama_models(self) -> None:
"""
Validate models for the Ollama provider.
"""
try:
models_list = ollama.list()
availables = {model['model'] for model in models_list['models']}
not_availables: list[str] = []
for model in self.models.ollama:
if model.name in availables:
model.model = Ollama
else:
not_availables.append(model.name)
if not_availables:
log.warning(f"Ollama models not available: {not_availables}")
self.models.ollama = [model for model in self.models.ollama if model.model]
except Exception as e:
log.warning(f"Ollama is not running or not reachable: {e}")

View File

@@ -83,13 +83,15 @@ class ChatManager:
label="Modello da usare" label="Modello da usare"
) )
provider.change(fn=self.inputs.choose_team_leader, inputs=provider, outputs=None) provider.change(fn=self.inputs.choose_team_leader, inputs=provider, outputs=None)
provider.value = self.inputs.team_leader_model.label
style = gr.Dropdown( strategy = gr.Dropdown(
choices=self.inputs.list_strategies_names(), choices=self.inputs.list_strategies_names(),
type="index", type="index",
label="Stile di investimento" label="Stile di investimento"
) )
style.change(fn=self.inputs.choose_strategy, inputs=style, outputs=None) strategy.change(fn=self.inputs.choose_strategy, inputs=strategy, outputs=None)
strategy.value = self.inputs.strategy.label
chat = gr.ChatInterface( chat = gr.ChatInterface(
fn=self.gradio_respond fn=self.gradio_respond