Refactor configs dei modelli + fix chat interface default selection
This commit is contained in:
@@ -54,6 +54,8 @@ class AppModel(BaseModel):
|
|||||||
output_schema=output_schema
|
output_schema=output_schema
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class APIConfig(BaseModel):
|
class APIConfig(BaseModel):
|
||||||
retry_attempts: int = 3
|
retry_attempts: int = 3
|
||||||
retry_delay_seconds: int = 2
|
retry_delay_seconds: int = 2
|
||||||
@@ -61,11 +63,15 @@ class APIConfig(BaseModel):
|
|||||||
news_providers: list[str] = []
|
news_providers: list[str] = []
|
||||||
social_providers: list[str] = []
|
social_providers: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Strategy(BaseModel):
|
class Strategy(BaseModel):
|
||||||
name: str = "Conservative"
|
name: str = "Conservative"
|
||||||
label: str = "Conservative"
|
label: str = "Conservative"
|
||||||
description: str = "Focus on low-risk investments with steady returns."
|
description: str = "Focus on low-risk investments with steady returns."
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsConfig(BaseModel):
|
class ModelsConfig(BaseModel):
|
||||||
gemini: list[AppModel] = [AppModel()]
|
gemini: list[AppModel] = [AppModel()]
|
||||||
ollama: list[AppModel] = []
|
ollama: list[AppModel] = []
|
||||||
@@ -74,6 +80,53 @@ class ModelsConfig(BaseModel):
|
|||||||
def all_models(self) -> list[AppModel]:
|
def all_models(self) -> list[AppModel]:
|
||||||
return self.gemini + self.ollama
|
return self.gemini + self.ollama
|
||||||
|
|
||||||
|
def validate_models(self) -> None:
|
||||||
|
"""
|
||||||
|
Validate the configured models for each provider.
|
||||||
|
"""
|
||||||
|
self.__validate_online_models(self.gemini, clazz=Gemini, key="GOOGLE_API_KEY")
|
||||||
|
self.__validate_ollama_models()
|
||||||
|
|
||||||
|
def __validate_online_models(self, models: list[AppModel], clazz: type[Model], key: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Validate models for online providers like Gemini.
|
||||||
|
Args:
|
||||||
|
models: list of AppModel instances to validate
|
||||||
|
clazz: class of the model (e.g. Gemini)
|
||||||
|
key: API key required for the provider (optional)
|
||||||
|
"""
|
||||||
|
if key and os.getenv(key) is None:
|
||||||
|
log.warning(f"No {key} set in environment variables for provider.")
|
||||||
|
models.clear()
|
||||||
|
return
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
model.model = clazz
|
||||||
|
|
||||||
|
def __validate_ollama_models(self) -> None:
|
||||||
|
"""
|
||||||
|
Validate models for the Ollama provider.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
models_list = ollama.list()
|
||||||
|
availables = {model['model'] for model in models_list['models']}
|
||||||
|
not_availables: list[str] = []
|
||||||
|
|
||||||
|
for model in self.ollama:
|
||||||
|
if model.name in availables:
|
||||||
|
model.model = Ollama
|
||||||
|
else:
|
||||||
|
not_availables.append(model.name)
|
||||||
|
if not_availables:
|
||||||
|
log.warning(f"Ollama models not available: {not_availables}")
|
||||||
|
|
||||||
|
self.ollama = [model for model in self.ollama if model.model]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.warning(f"Ollama is not running or not reachable: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AgentsConfigs(BaseModel):
|
class AgentsConfigs(BaseModel):
|
||||||
strategy: str = "Conservative"
|
strategy: str = "Conservative"
|
||||||
team_model: str = "gemini-2.0-flash"
|
team_model: str = "gemini-2.0-flash"
|
||||||
@@ -121,7 +174,7 @@ class AppConfig(BaseModel):
|
|||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.set_logging_level()
|
self.set_logging_level()
|
||||||
self.validate_models()
|
self.models.validate_models()
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def get_model_by_name(self, name: str) -> AppModel:
|
def get_model_by_name(self, name: str) -> AppModel:
|
||||||
@@ -189,53 +242,3 @@ class AppConfig(BaseModel):
|
|||||||
logger = logging.getLogger(logger_name)
|
logger = logging.getLogger(logger_name)
|
||||||
logger.handlers.clear()
|
logger.handlers.clear()
|
||||||
logger.propagate = True
|
logger.propagate = True
|
||||||
|
|
||||||
def validate_models(self) -> None:
|
|
||||||
"""
|
|
||||||
Validate the configured models for each provider.
|
|
||||||
"""
|
|
||||||
self.__validate_online_models("gemini", clazz=Gemini, key="GOOGLE_API_KEY")
|
|
||||||
self.__validate_ollama_models()
|
|
||||||
|
|
||||||
def __validate_online_models(self, provider: str, clazz: type[Model], key: str | None = None) -> None:
|
|
||||||
"""
|
|
||||||
Validate models for online providers like Gemini.
|
|
||||||
Args:
|
|
||||||
provider: name of the provider (e.g. "gemini")
|
|
||||||
clazz: class of the model (e.g. Gemini)
|
|
||||||
key: API key required for the provider (optional)
|
|
||||||
"""
|
|
||||||
if getattr(self.models, provider) is None:
|
|
||||||
log.warning(f"No models configured for provider '{provider}'.")
|
|
||||||
|
|
||||||
models: list[AppModel] = getattr(self.models, provider)
|
|
||||||
if key and os.getenv(key) is None:
|
|
||||||
log.warning(f"No {key} set in environment variables for {provider}.")
|
|
||||||
models.clear()
|
|
||||||
return
|
|
||||||
|
|
||||||
for model in models:
|
|
||||||
model.model = clazz
|
|
||||||
|
|
||||||
def __validate_ollama_models(self) -> None:
|
|
||||||
"""
|
|
||||||
Validate models for the Ollama provider.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
models_list = ollama.list()
|
|
||||||
availables = {model['model'] for model in models_list['models']}
|
|
||||||
not_availables: list[str] = []
|
|
||||||
|
|
||||||
for model in self.models.ollama:
|
|
||||||
if model.name in availables:
|
|
||||||
model.model = Ollama
|
|
||||||
else:
|
|
||||||
not_availables.append(model.name)
|
|
||||||
if not_availables:
|
|
||||||
log.warning(f"Ollama models not available: {not_availables}")
|
|
||||||
|
|
||||||
self.models.ollama = [model for model in self.models.ollama if model.model]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log.warning(f"Ollama is not running or not reachable: {e}")
|
|
||||||
|
|
||||||
|
|||||||
@@ -84,14 +84,16 @@ class ChatManager:
|
|||||||
label="Modello da usare"
|
label="Modello da usare"
|
||||||
)
|
)
|
||||||
provider.change(fn=self.inputs.choose_team_leader, inputs=provider, outputs=None)
|
provider.change(fn=self.inputs.choose_team_leader, inputs=provider, outputs=None)
|
||||||
|
provider.value = self.inputs.team_leader_model.label
|
||||||
|
|
||||||
style = gr.Dropdown(
|
strategy = gr.Dropdown(
|
||||||
choices=self.inputs.list_strategies_names(),
|
choices=self.inputs.list_strategies_names(),
|
||||||
value=self.inputs.strategy.label,
|
value=self.inputs.strategy.label,
|
||||||
type="index",
|
type="index",
|
||||||
label="Stile di investimento"
|
label="Stile di investimento"
|
||||||
)
|
)
|
||||||
style.change(fn=self.inputs.choose_strategy, inputs=style, outputs=None)
|
strategy.change(fn=self.inputs.choose_strategy, inputs=strategy, outputs=None)
|
||||||
|
strategy.value = self.inputs.strategy.label
|
||||||
|
|
||||||
chat = gr.ChatInterface(
|
chat = gr.ChatInterface(
|
||||||
fn=self.gradio_respond
|
fn=self.gradio_respond
|
||||||
|
|||||||
Reference in New Issue
Block a user