Fix configs validation #66
@@ -78,36 +78,38 @@ class Strategy(BaseModel):
|
|||||||
|
|
||||||
class ModelsConfig(BaseModel):
|
class ModelsConfig(BaseModel):
|
||||||
gemini: list[AppModel] = [AppModel()]
|
gemini: list[AppModel] = [AppModel()]
|
||||||
gpt: list[AppModel] = [AppModel(name="gpt-4o", label="OpenAIChat")]
|
gpt: list[AppModel] = []
|
||||||
mistral: list[AppModel] = [AppModel(name="mistral-large-latest", label="Mistral")]
|
mistral: list[AppModel] = []
|
||||||
deepseek: list[AppModel] = [AppModel(name="deepseek-chat", label="DeepSeek")]
|
deepseek: list[AppModel] = []
|
||||||
# xai: list[AppModel] = [AppModel(name="grok-3", label="xAI")]
|
|
||||||
ollama: list[AppModel] = []
|
ollama: list[AppModel] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_models(self) -> list[AppModel]:
|
def all_models(self) -> list[AppModel]:
|
||||||
return self.gemini + self.ollama + self.gpt + self.mistral + self.deepseek # + self.xai
|
return self.gemini + self.ollama + self.gpt + self.mistral + self.deepseek
|
||||||
|
|
||||||
def validate_models(self) -> None:
|
def validate_models(self) -> None:
|
||||||
"""
|
"""
|
||||||
Validate the configured models for each provider.
|
Validate the configured models for each supported provider.
|
||||||
"""
|
"""
|
||||||
self.__validate_online_models(self.gemini, clazz=Gemini, key="GOOGLE_API_KEY")
|
self.__validate_online_models(self.gemini, clazz=Gemini, key="GOOGLE_API_KEY")
|
||||||
self.__validate_online_models(self.gpt, clazz=OpenAIChat, key="OPENAI_API_KEY")
|
self.__validate_online_models(self.gpt, clazz=OpenAIChat, key="OPENAI_API_KEY")
|
||||||
self.__validate_online_models(self.mistral, clazz=MistralChat, key="MISTRAL_API_KEY")
|
self.__validate_online_models(self.mistral, clazz=MistralChat, key="MISTRAL_API_KEY")
|
||||||
self.__validate_online_models(self.deepseek, clazz=DeepSeek, key="DEEPSEEK_API_KEY")
|
self.__validate_online_models(self.deepseek, clazz=DeepSeek, key="DEEPSEEK_API_KEY")
|
||||||
# self.__validate_online_models(self.xai, clazz=xAI, key="XAI_API_KEY")
|
|
||||||
|
|
||||||
self.__validate_ollama_models()
|
self.__validate_ollama_models()
|
||||||
|
|
||||||
def __validate_online_models(self, models: list[AppModel], clazz: type[Model], key: str | None = None) -> None:
|
def __validate_online_models(self, models: list[AppModel], clazz: type[Model], key: str | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Validate models for online providers like Gemini.
|
Validate models for online providers that require an API key.
|
||||||
|
If the models list is empty, no validation is performed and the method returns immediately.
|
||||||
|
If the API key is not set, the models list will be cleared.
|
||||||
Args:
|
Args:
|
||||||
models: list of AppModel instances to validate
|
models: list of AppModel instances to validate
|
||||||
clazz: class of the model (e.g. Gemini)
|
clazz: class of the model (e.g. Gemini)
|
||||||
key: API key required for the provider (optional)
|
key: API key required for the provider (optional)
|
||||||
"""
|
"""
|
||||||
|
if not models:
|
||||||
|
return
|
||||||
|
|
||||||
if key and os.getenv(key) is None:
|
if key and os.getenv(key) is None:
|
||||||
log.warning(f"No {key} set in environment variables for {clazz.__name__}.")
|
log.warning(f"No {key} set in environment variables for {clazz.__name__}.")
|
||||||
models.clear()
|
models.clear()
|
||||||
@@ -131,7 +133,7 @@ class ModelsConfig(BaseModel):
|
|||||||
else:
|
else:
|
||||||
not_availables.append(model.name)
|
not_availables.append(model.name)
|
||||||
if not_availables:
|
if not_availables:
|
||||||
log.warning(f"Ollama models not available: {not_availables}")
|
log.warning(f"Ollama models not available, but defined in configs: {not_availables}")
|
||||||
|
|
||||||
self.ollama = [model for model in self.ollama if model.model]
|
self.ollama = [model for model in self.ollama if model.model]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user