* Modificati caricamenti dei provider google e anthropic.
* Modificata _predict_google: ora è funzionante. + Aggiunto debug per provider attivi trovati. # NB: Sia le chiamate che le chiamate agli modelli potrebbero non essere aggiornati (non ho ancora letto la documentazione recente e potrebbero essere da modificare).
This commit is contained in:
@@ -26,13 +26,17 @@ class PredictorAgent:
|
|||||||
# Anthropic
|
# Anthropic
|
||||||
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
||||||
if anthropic_key:
|
if anthropic_key:
|
||||||
client = anthropic.Client(api_key=anthropic_key)
|
client = anthropic.Anthropic(api_key=anthropic_key)
|
||||||
self.providers["anthropic"] = {"type": "anthropic", "client": client, "model": "claude-3"}
|
self.providers["anthropic"] = {
|
||||||
|
"type": "anthropic",
|
||||||
|
"client": client,
|
||||||
|
"model": "claude-3-haiku-20240307"
|
||||||
|
}
|
||||||
|
|
||||||
# Google Gemini
|
# Google Gemini
|
||||||
google_key = os.getenv("GEMINI_API_KEY")
|
google_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||||
if google_key:
|
if google_key:
|
||||||
client = Client(credentials={"api_key": google_key})
|
client = Client(api_key=google_key)
|
||||||
self.providers["google"] = {"type": "google", "client": client, "model": "gemini-1.5-flash"}
|
self.providers["google"] = {"type": "google", "client": client, "model": "gemini-1.5-flash"}
|
||||||
|
|
||||||
# DeepSeek
|
# DeepSeek
|
||||||
@@ -40,6 +44,8 @@ class PredictorAgent:
|
|||||||
if deepseek_key:
|
if deepseek_key:
|
||||||
self.providers["deepseek"] = {"type": "deepseek", "api_key": deepseek_key, "model": "deepseek-chat"}
|
self.providers["deepseek"] = {"type": "deepseek", "api_key": deepseek_key, "model": "deepseek-chat"}
|
||||||
|
|
||||||
|
print("✅ Providers attivi:", list(self.providers.keys()))
|
||||||
|
|
||||||
def predict(self, data, sentiment, style="conservative", provider="mock"):
|
def predict(self, data, sentiment, style="conservative", provider="mock"):
|
||||||
provider = provider.lower()
|
provider = provider.lower()
|
||||||
if provider == "mock" or provider not in self.providers:
|
if provider == "mock" or provider not in self.providers:
|
||||||
@@ -114,11 +120,13 @@ class PredictorAgent:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _predict_google(prompt, client, model):
|
def _predict_google(prompt, client, model):
|
||||||
response = client.generate_text(
|
response = client.models.generate_content(
|
||||||
model=model,
|
model=model,
|
||||||
prompt=prompt,
|
contents=prompt,
|
||||||
max_output_tokens=300,
|
config={
|
||||||
temperature=0.7
|
"temperature": 0.7,
|
||||||
|
"max_output_tokens": 300
|
||||||
|
}
|
||||||
)
|
)
|
||||||
return response.text.strip()
|
return response.text.strip()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user