Refactor agent handling in Pipeline; add tests for report generation and team agent responses
This commit is contained in:
@@ -1,5 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
from agno.agent import Agent
|
||||
from agno.team import Team
|
||||
from agno.tools.reasoning import ReasoningTools
|
||||
from app.api.tools import *
|
||||
from app.configs import AppConfig
|
||||
from app.agents.prompts import *
|
||||
|
||||
|
||||
|
||||
@@ -11,8 +16,6 @@ class QueryOutputs(BaseModel):
|
||||
response: str
|
||||
is_crypto: bool
|
||||
|
||||
|
||||
|
||||
class PipelineInputs:
|
||||
"""
|
||||
Classe necessaria per passare gli input alla Pipeline.
|
||||
@@ -70,7 +73,46 @@ class PipelineInputs:
|
||||
"""
|
||||
return [strat.label for strat in self.configs.strategies]
|
||||
|
||||
def get_query_inputs(self) -> QueryInputs:
|
||||
"""
|
||||
Restituisce gli input per l'agente di verifica della query.
|
||||
"""
|
||||
return QueryInputs(
|
||||
user_query=self.user_query,
|
||||
strategy=self.strategy.label,
|
||||
)
|
||||
|
||||
# ======================
|
||||
# Agent getters
|
||||
# ======================
|
||||
def get_agent_team(self) -> Team:
|
||||
market, news, social = self.get_tools()
|
||||
market_agent = self.team_model.get_agent(MARKET_INSTRUCTIONS, "Market Agent", tools=[market])
|
||||
news_agent = self.team_model.get_agent(NEWS_INSTRUCTIONS, "News Agent", tools=[news])
|
||||
social_agent = self.team_model.get_agent(SOCIAL_INSTRUCTIONS, "Socials Agent", tools=[social])
|
||||
return Team(
|
||||
model=self.team_leader_model.get_model(TEAM_LEADER_INSTRUCTIONS),
|
||||
name="CryptoAnalysisTeam",
|
||||
tools=[ReasoningTools()],
|
||||
members=[market_agent, news_agent, social_agent],
|
||||
)
|
||||
|
||||
def get_agent_query_checker(self) -> Agent:
|
||||
return self.query_analyzer_model.get_agent(QUERY_CHECK_INSTRUCTIONS, "Query Check Agent", output_schema=QueryOutputs)
|
||||
|
||||
def get_agent_report_generator(self) -> Agent:
|
||||
return self.report_generation_model.get_agent(REPORT_GENERATION_INSTRUCTIONS, "Report Generator Agent")
|
||||
|
||||
def get_tools(self) -> tuple[MarketAPIsTool, NewsAPIsTool, SocialAPIsTool]:
|
||||
"""
|
||||
Restituisce la lista di tools disponibili per gli agenti.
|
||||
"""
|
||||
api = self.configs.api
|
||||
|
||||
market_tool = MarketAPIsTool(currency=api.currency)
|
||||
market_tool.handler.set_retries(api.retry_attempts, api.retry_delay_seconds)
|
||||
news_tool = NewsAPIsTool()
|
||||
news_tool.handler.set_retries(api.retry_attempts, api.retry_delay_seconds)
|
||||
social_tool = SocialAPIsTool()
|
||||
social_tool.handler.set_retries(api.retry_attempts, api.retry_delay_seconds)
|
||||
return market_tool, news_tool, social_tool
|
||||
@@ -4,14 +4,10 @@ import logging
|
||||
import random
|
||||
from typing import Any, Callable
|
||||
from agno.agent import RunEvent
|
||||
from agno.team import Team
|
||||
from agno.tools.reasoning import ReasoningTools
|
||||
from agno.run.workflow import WorkflowRunEvent
|
||||
from agno.workflow.types import StepInput, StepOutput
|
||||
from agno.workflow.step import Step
|
||||
from agno.workflow.workflow import Workflow
|
||||
from app.api.tools import *
|
||||
from app.agents.prompts import *
|
||||
from app.agents.core import *
|
||||
|
||||
logging = logging.getLogger("pipeline")
|
||||
@@ -91,28 +87,18 @@ class Pipeline:
|
||||
L'istanza di Workflow costruita.
|
||||
"""
|
||||
# Step 1: Crea gli agenti e il team
|
||||
q_check_agent = self.inputs.query_analyzer_model.get_agent(instructions=QUERY_CHECK_INSTRUCTIONS, name="QueryCheckAgent", output_schema=QueryOutputs)
|
||||
report_agent = self.inputs.report_generation_model.get_agent(instructions=REPORT_GENERATION_INSTRUCTIONS, name="ReportGeneratorAgent")
|
||||
|
||||
market_tool, news_tool, social_tool = self.get_tools()
|
||||
market_agent = self.inputs.team_model.get_agent(instructions=MARKET_INSTRUCTIONS, name="MarketAgent", tools=[market_tool])
|
||||
news_agent = self.inputs.team_model.get_agent(instructions=NEWS_INSTRUCTIONS, name="NewsAgent", tools=[news_tool])
|
||||
social_agent = self.inputs.team_model.get_agent(instructions=SOCIAL_INSTRUCTIONS, name="SocialAgent", tools=[social_tool])
|
||||
team = Team(
|
||||
model=self.inputs.team_leader_model.get_model(COORDINATOR_INSTRUCTIONS),
|
||||
name="CryptoAnalysisTeam",
|
||||
tools=[ReasoningTools()],
|
||||
members=[market_agent, news_agent, social_agent],
|
||||
)
|
||||
team = self.inputs.get_agent_team()
|
||||
query_check = self.inputs.get_agent_query_checker()
|
||||
report = self.inputs.get_agent_report_generator()
|
||||
|
||||
# Step 2: Crea gli steps
|
||||
def condition_query_ok(step_input: StepInput) -> StepOutput:
|
||||
val = step_input.previous_step_content
|
||||
return StepOutput(stop=not val.is_crypto) if isinstance(val, QueryOutputs) else StepOutput(stop=True)
|
||||
|
||||
query_check = Step(name=PipelineEvent.QUERY_CHECK, agent=q_check_agent)
|
||||
query_check = Step(name=PipelineEvent.QUERY_CHECK, agent=query_check)
|
||||
info_recovery = Step(name=PipelineEvent.INFO_RECOVERY, team=team)
|
||||
report_generation = Step(name=PipelineEvent.REPORT_GENERATION, agent=report_agent)
|
||||
report_generation = Step(name=PipelineEvent.REPORT_GENERATION, agent=report)
|
||||
|
||||
# Step 3: Ritorna il workflow completo
|
||||
return Workflow(name="App Workflow", steps=[
|
||||
@@ -122,22 +108,6 @@ class Pipeline:
|
||||
report_generation
|
||||
])
|
||||
|
||||
|
||||
def get_tools(self) -> tuple[MarketAPIsTool, NewsAPIsTool, SocialAPIsTool]:
|
||||
"""
|
||||
Restituisce la lista di tools disponibili per gli agenti.
|
||||
"""
|
||||
api = self.inputs.configs.api
|
||||
|
||||
market_tool = MarketAPIsTool(currency=api.currency)
|
||||
market_tool.handler.set_retries(api.retry_attempts, api.retry_delay_seconds)
|
||||
news_tool = NewsAPIsTool()
|
||||
news_tool.handler.set_retries(api.retry_attempts, api.retry_delay_seconds)
|
||||
social_tool = SocialAPIsTool()
|
||||
social_tool.handler.set_retries(api.retry_attempts, api.retry_delay_seconds)
|
||||
|
||||
return (market_tool, news_tool, social_tool)
|
||||
|
||||
@classmethod
|
||||
async def run(cls, workflow: Workflow, query: QueryInputs, events: list[tuple[PipelineEvent, Callable[[Any], None]]]) -> str:
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,7 @@ def __load_prompt(file_name: str) -> str:
|
||||
file_path = __PROMPTS_PATH / file_name
|
||||
return file_path.read_text(encoding='utf-8').strip()
|
||||
|
||||
COORDINATOR_INSTRUCTIONS = __load_prompt("team_leader.txt")
|
||||
TEAM_LEADER_INSTRUCTIONS = __load_prompt("team_leader.txt")
|
||||
MARKET_INSTRUCTIONS = __load_prompt("team_market.txt")
|
||||
NEWS_INSTRUCTIONS = __load_prompt("team_news.txt")
|
||||
SOCIAL_INSTRUCTIONS = __load_prompt("team_social.txt")
|
||||
@@ -14,7 +14,7 @@ QUERY_CHECK_INSTRUCTIONS = __load_prompt("query_check.txt")
|
||||
REPORT_GENERATION_INSTRUCTIONS = __load_prompt("report_generation.txt")
|
||||
|
||||
__all__ = [
|
||||
"COORDINATOR_INSTRUCTIONS",
|
||||
"TEAM_LEADER_INSTRUCTIONS",
|
||||
"MARKET_INSTRUCTIONS",
|
||||
"NEWS_INSTRUCTIONS",
|
||||
"SOCIAL_INSTRUCTIONS",
|
||||
|
||||
31
tests/agents/test_report.py
Normal file
31
tests/agents/test_report.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
from app.agents.prompts import REPORT_GENERATION_INSTRUCTIONS
|
||||
from app.configs import AppConfig
|
||||
|
||||
|
||||
class TestReportGenerationAgent:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
self.configs = AppConfig.load()
|
||||
self.model = self.configs.get_model_by_name("qwen3:1.7b")
|
||||
self.agent = self.model.get_agent(REPORT_GENERATION_INSTRUCTIONS)
|
||||
|
||||
def test_report_generation(self):
|
||||
sample_data = """
|
||||
The analysis reported from the Market Agent have highlighted the following key metrics for the cryptocurrency market:
|
||||
Bitcoin (BTC) has shown strong performance over the last 24 hours with a price of $30,000 and a Market Cap of $600 Billion
|
||||
Ethereum (ETH) is currently priced at $2,000 with a Market Cap of $250 Billion and a 24h Volume of $20 Billion.
|
||||
The overall market sentiment is bullish with a 5% increase in total market capitalization.
|
||||
No significant regulatory news has been reported and the social media sentiment remains unknown.
|
||||
"""
|
||||
|
||||
response = self.agent.run(sample_data) #type: ignore
|
||||
assert response is not None
|
||||
assert response.content is not None
|
||||
content = response.content
|
||||
assert isinstance(content, str)
|
||||
print(content)
|
||||
assert "Bitcoin" in content
|
||||
assert "Ethereum" in content
|
||||
assert "Summary" in content
|
||||
|
||||
38
tests/agents/test_team.py
Normal file
38
tests/agents/test_team.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from app.agents.core import PipelineInputs
|
||||
from app.agents.prompts import *
|
||||
from app.configs import AppConfig
|
||||
|
||||
|
||||
# fix warning about no event loop
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestTeamAgent:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
self.configs = AppConfig.load()
|
||||
self.configs.agents.team_model = "qwen3:1.7b"
|
||||
self.configs.agents.team_leader_model = "qwen3:1.7b"
|
||||
self.inputs = PipelineInputs(self.configs)
|
||||
self.team = self.inputs.get_agent_team()
|
||||
|
||||
def test_team_agent_response(self):
|
||||
self.inputs.user_query = "Is Bitcoin a good investment now?"
|
||||
inputs = self.inputs.get_query_inputs()
|
||||
response = self.team.run(inputs) # type: ignore
|
||||
|
||||
assert response is not None
|
||||
assert response.content is not None
|
||||
content = response.content
|
||||
print(content)
|
||||
assert isinstance(content, str)
|
||||
assert "Bitcoin" in content
|
||||
assert False
|
||||
Reference in New Issue
Block a user