From 22715237e10dc77c99a889e5afd5ee7ce5d871d0 Mon Sep 17 00:00:00 2001 From: Berack96 Date: Tue, 21 Oct 2025 16:57:28 +0200 Subject: [PATCH] Refactor configs dei modelli + fix chat interface default selection --- src/app/configs.py | 105 ++++++++++++++++++++------------------ src/app/interface/chat.py | 6 ++- 2 files changed, 58 insertions(+), 53 deletions(-) diff --git a/src/app/configs.py b/src/app/configs.py index 45b5b01..901de66 100644 --- a/src/app/configs.py +++ b/src/app/configs.py @@ -54,15 +54,21 @@ class AppModel(BaseModel): output_schema=output_schema ) + + class APIConfig(BaseModel): retry_attempts: int = 3 retry_delay_seconds: int = 2 + + class Strategy(BaseModel): name: str = "Conservative" label: str = "Conservative" description: str = "Focus on low-risk investments with steady returns." + + class ModelsConfig(BaseModel): gemini: list[AppModel] = [AppModel()] ollama: list[AppModel] = [] @@ -71,6 +77,53 @@ class ModelsConfig(BaseModel): def all_models(self) -> list[AppModel]: return self.gemini + self.ollama + def validate_models(self) -> None: + """ + Validate the configured models for each provider. + """ + self.__validate_online_models(self.gemini, clazz=Gemini, key="GOOGLE_API_KEY") + self.__validate_ollama_models() + + def __validate_online_models(self, models: list[AppModel], clazz: type[Model], key: str | None = None) -> None: + """ + Validate models for online providers like Gemini. + Args: + models: list of AppModel instances to validate + clazz: class of the model (e.g. Gemini) + key: API key required for the provider (optional) + """ + if key and os.getenv(key) is None: + log.warning(f"No {key} set in environment variables for provider.") + models.clear() + return + + for model in models: + model.model = clazz + + def __validate_ollama_models(self) -> None: + """ + Validate models for the Ollama provider. + """ + try: + models_list = ollama.list() + availables = {model['model'] for model in models_list['models']} + not_availables: list[str] = [] + + for model in self.ollama: + if model.name in availables: + model.model = Ollama + else: + not_availables.append(model.name) + if not_availables: + log.warning(f"Ollama models not available: {not_availables}") + + self.ollama = [model for model in self.ollama if model.model] + + except Exception as e: + log.warning(f"Ollama is not running or not reachable: {e}") + + + class AgentsConfigs(BaseModel): strategy: str = "Conservative" team_model: str = "gemini-2.0-flash" @@ -118,7 +171,7 @@ class AppConfig(BaseModel): super().__init__(*args, **kwargs) self.set_logging_level() - self.validate_models() + self.models.validate_models() self._initialized = True def get_model_by_name(self, name: str) -> AppModel: @@ -186,53 +239,3 @@ class AppConfig(BaseModel): logger = logging.getLogger(logger_name) logger.handlers.clear() logger.propagate = True - - def validate_models(self) -> None: - """ - Validate the configured models for each provider. - """ - self.__validate_online_models("gemini", clazz=Gemini, key="GOOGLE_API_KEY") - self.__validate_ollama_models() - - def __validate_online_models(self, provider: str, clazz: type[Model], key: str | None = None) -> None: - """ - Validate models for online providers like Gemini. - Args: - provider: name of the provider (e.g. "gemini") - clazz: class of the model (e.g. Gemini) - key: API key required for the provider (optional) - """ - if getattr(self.models, provider) is None: - log.warning(f"No models configured for provider '{provider}'.") - - models: list[AppModel] = getattr(self.models, provider) - if key and os.getenv(key) is None: - log.warning(f"No {key} set in environment variables for {provider}.") - models.clear() - return - - for model in models: - model.model = clazz - - def __validate_ollama_models(self) -> None: - """ - Validate models for the Ollama provider. - """ - try: - models_list = ollama.list() - availables = {model['model'] for model in models_list['models']} - not_availables: list[str] = [] - - for model in self.models.ollama: - if model.name in availables: - model.model = Ollama - else: - not_availables.append(model.name) - if not_availables: - log.warning(f"Ollama models not available: {not_availables}") - - self.models.ollama = [model for model in self.models.ollama if model.model] - - except Exception as e: - log.warning(f"Ollama is not running or not reachable: {e}") - diff --git a/src/app/interface/chat.py b/src/app/interface/chat.py index 150197b..37529f4 100644 --- a/src/app/interface/chat.py +++ b/src/app/interface/chat.py @@ -83,13 +83,15 @@ class ChatManager: label="Modello da usare" ) provider.change(fn=self.inputs.choose_team_leader, inputs=provider, outputs=None) + provider.value = self.inputs.team_leader_model.label - style = gr.Dropdown( + strategy = gr.Dropdown( choices=self.inputs.list_strategies_names(), type="index", label="Stile di investimento" ) - style.change(fn=self.inputs.choose_strategy, inputs=style, outputs=None) + strategy.change(fn=self.inputs.choose_strategy, inputs=strategy, outputs=None) + strategy.value = self.inputs.strategy.label chat = gr.ChatInterface( fn=self.gradio_respond