Fix configs validation #66
@@ -1,11 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from app.configs import AppConfig
|
from app.configs import AppConfig
|
||||||
from app.interface import *
|
from app.interface import *
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
# =====================
|
# =====================
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
configs = AppConfig.load()
|
configs = AppConfig.load()
|
||||||
@@ -30,3 +32,7 @@ if __name__ == "__main__":
|
|||||||
logging.info("Shutting down due to KeyboardInterrupt")
|
logging.info("Shutting down due to KeyboardInterrupt")
|
||||||
finally:
|
finally:
|
||||||
gradio.close()
|
gradio.close()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Application failed to start: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|||||||
@@ -78,36 +78,38 @@ class Strategy(BaseModel):
|
|||||||
|
|
||||||
class ModelsConfig(BaseModel):
|
class ModelsConfig(BaseModel):
|
||||||
gemini: list[AppModel] = [AppModel()]
|
gemini: list[AppModel] = [AppModel()]
|
||||||
gpt: list[AppModel] = [AppModel(name="gpt-4o", label="OpenAIChat")]
|
gpt: list[AppModel] = []
|
||||||
mistral: list[AppModel] = [AppModel(name="mistral-large-latest", label="Mistral")]
|
mistral: list[AppModel] = []
|
||||||
deepseek: list[AppModel] = [AppModel(name="deepseek-chat", label="DeepSeek")]
|
deepseek: list[AppModel] = []
|
||||||
# xai: list[AppModel] = [AppModel(name="grok-3", label="xAI")]
|
|
||||||
ollama: list[AppModel] = []
|
ollama: list[AppModel] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_models(self) -> list[AppModel]:
|
def all_models(self) -> list[AppModel]:
|
||||||
return self.gemini + self.ollama + self.gpt + self.mistral + self.deepseek # + self.xai
|
return self.gemini + self.ollama + self.gpt + self.mistral + self.deepseek
|
||||||
|
|
||||||
def validate_models(self) -> None:
|
def validate_models(self) -> None:
|
||||||
"""
|
"""
|
||||||
Validate the configured models for each provider.
|
Validate the configured models for each supported provider.
|
||||||
"""
|
"""
|
||||||
self.__validate_online_models(self.gemini, clazz=Gemini, key="GOOGLE_API_KEY")
|
self.__validate_online_models(self.gemini, clazz=Gemini, key="GOOGLE_API_KEY")
|
||||||
self.__validate_online_models(self.gpt, clazz=OpenAIChat, key="OPENAI_API_KEY")
|
self.__validate_online_models(self.gpt, clazz=OpenAIChat, key="OPENAI_API_KEY")
|
||||||
self.__validate_online_models(self.mistral, clazz=MistralChat, key="MISTRAL_API_KEY")
|
self.__validate_online_models(self.mistral, clazz=MistralChat, key="MISTRAL_API_KEY")
|
||||||
self.__validate_online_models(self.deepseek, clazz=DeepSeek, key="DEEPSEEK_API_KEY")
|
self.__validate_online_models(self.deepseek, clazz=DeepSeek, key="DEEPSEEK_API_KEY")
|
||||||
# self.__validate_online_models(self.xai, clazz=xAI, key="XAI_API_KEY")
|
|
||||||
|
|
||||||
self.__validate_ollama_models()
|
self.__validate_ollama_models()
|
||||||
|
|
||||||
def __validate_online_models(self, models: list[AppModel], clazz: type[Model], key: str | None = None) -> None:
|
def __validate_online_models(self, models: list[AppModel], clazz: type[Model], key: str | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Validate models for online providers like Gemini.
|
Validate models for online providers that require an API key.
|
||||||
|
If the models list is empty, no validation is performed and the method returns immediately.
|
||||||
|
If the API key is not set, the models list will be cleared.
|
||||||
Args:
|
Args:
|
||||||
models: list of AppModel instances to validate
|
models: list of AppModel instances to validate
|
||||||
clazz: class of the model (e.g. Gemini)
|
clazz: class of the model (e.g. Gemini)
|
||||||
key: API key required for the provider (optional)
|
key: API key required for the provider (optional)
|
||||||
"""
|
"""
|
||||||
|
if not models:
|
||||||
|
return
|
||||||
|
|
||||||
if key and os.getenv(key) is None:
|
if key and os.getenv(key) is None:
|
||||||
log.warning(f"No {key} set in environment variables for {clazz.__name__}.")
|
log.warning(f"No {key} set in environment variables for {clazz.__name__}.")
|
||||||
models.clear()
|
models.clear()
|
||||||
@@ -131,7 +133,7 @@ class ModelsConfig(BaseModel):
|
|||||||
else:
|
else:
|
||||||
not_availables.append(model.name)
|
not_availables.append(model.name)
|
||||||
if not_availables:
|
if not_availables:
|
||||||
log.warning(f"Ollama models not available: {not_availables}")
|
log.warning(f"Ollama models not available, but defined in configs: {not_availables}")
|
||||||
|
|
||||||
self.ollama = [model for model in self.ollama if model.model]
|
self.ollama = [model for model in self.ollama if model.model]
|
||||||
|
|
||||||
@@ -147,6 +149,28 @@ class AgentsConfigs(BaseModel):
|
|||||||
query_analyzer_model: str = "gemini-2.0-flash"
|
query_analyzer_model: str = "gemini-2.0-flash"
|
||||||
|
|
|||||||
report_generation_model: str = "gemini-2.0-flash"
|
report_generation_model: str = "gemini-2.0-flash"
|
||||||
|
|
||||||
|
def validate_defaults(self, configs: 'AppConfig') -> None:
|
||||||
|
"""
|
||||||
|
Validate that the default models and strategy exist in the provided configurations.
|
||||||
|
If any default is not found, a ValueError is raised.
|
||||||
|
Args:
|
||||||
|
configs: the AppConfig instance containing models and strategies.
|
||||||
|
Raises:
|
||||||
|
ValueError if any default model or strategy is not found.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
configs.get_strategy_by_name(self.strategy)
|
||||||
|
except ValueError as e:
|
||||||
|
log.error(f"Default strategy '{self.strategy}' not found in configurations.")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
for model_name in [self.team_model, self.team_leader_model, self.query_analyzer_model, self.report_generation_model]:
|
||||||
|
try:
|
||||||
|
configs.get_model_by_name(model_name)
|
||||||
|
except ValueError as e:
|
||||||
|
log.error(f"Default agent model '{model_name}' not found in configurations.")
|
||||||
|
raise e
|
||||||
|
|
||||||
class AppConfig(BaseModel):
|
class AppConfig(BaseModel):
|
||||||
port: int = 8000
|
port: int = 8000
|
||||||
gradio_share: bool = False
|
gradio_share: bool = False
|
||||||
@@ -188,6 +212,7 @@ class AppConfig(BaseModel):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.set_logging_level()
|
self.set_logging_level()
|
||||||
self.models.validate_models()
|
self.models.validate_models()
|
||||||
|
self.agents.validate_defaults(self)
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def get_model_by_name(self, name: str) -> AppModel:
|
def get_model_by_name(self, name: str) -> AppModel:
|
||||||
|
|||||||
Reference in New Issue
Block a user
The docstring parameters don't match the actual method signature. The method accepts a single
configs: AppConfigparameter, not separatemodelsandstrategiesparameters. Update the Args section to reflect the actual parameter.