From 38daafce9abb68f405f931ab5b577d58e0cefa0a Mon Sep 17 00:00:00 2001 From: Giacomo Bertolazzi <31776951+Berack96@users.noreply.github.com> Date: Wed, 15 Oct 2025 14:00:39 +0200 Subject: [PATCH] Refactor team management (#26) * Refactor pipeline integration * remove direct pipeline dependency from ChatManager and TelegramApp * introduce PipelineInputs for better configuration management * listener personalizzati per eventi nella funzione di interazione della pipeline * added demos for agno * USD in configs * Dockerfile better cache --- Dockerfile | 10 +- configs.yaml | 2 +- demos/{example.py => agno_agent.py} | 0 demos/agno_workflow.py | 69 ++++++++++ src/app/__main__.py | 9 +- src/app/agents/__init__.py | 5 +- src/app/agents/pipeline.py | 189 ++++++++++++++++++++++------ src/app/agents/team.py | 25 ---- src/app/configs.py | 23 ++-- src/app/interface/chat.py | 20 +-- src/app/interface/telegram_app.py | 62 ++++----- 11 files changed, 281 insertions(+), 133 deletions(-) rename demos/{example.py => agno_agent.py} (100%) create mode 100644 demos/agno_workflow.py delete mode 100644 src/app/agents/team.py diff --git a/Dockerfile b/Dockerfile index 61d4bee..8c7489d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,16 +9,16 @@ ENV PATH="/root/.local/bin:$PATH" # Configuriamo UV per usare copy mode ed evitare problemi di linking ENV UV_LINK_MODE=copy -# Copiamo i file del progetto +# Creiamo l'ambiente virtuale con tutto già presente COPY pyproject.toml ./ COPY uv.lock ./ +RUN uv sync --frozen --no-dev +ENV PYTHONPATH="./src" + +# Copiamo i file del progetto COPY LICENSE ./ COPY src/ ./src/ COPY configs.yaml ./ -# Creiamo l'ambiente virtuale con tutto già presente -RUN uv sync -ENV PYTHONPATH="/src" - # Comando di avvio dell'applicazione CMD ["uv", "run", "src/app"] diff --git a/configs.yaml b/configs.yaml index 5d70b13..c0925b8 100644 --- a/configs.yaml +++ b/configs.yaml @@ -32,7 +32,7 @@ models: api: retry_attempts: 3 retry_delay_seconds: 2 - currency: EUR + currency: USD # TODO Magari implementare un sistema per settare i providers market_providers: [BinanceWrapper, YFinanceWrapper] news_providers: [GoogleNewsWrapper, DuckDuckGoWrapper] diff --git a/demos/example.py b/demos/agno_agent.py similarity index 100% rename from demos/example.py rename to demos/agno_agent.py diff --git a/demos/agno_workflow.py b/demos/agno_workflow.py new file mode 100644 index 0000000..13a48d2 --- /dev/null +++ b/demos/agno_workflow.py @@ -0,0 +1,69 @@ +import asyncio +from agno.agent import Agent +from agno.models.ollama import Ollama +from agno.run.workflow import WorkflowRunEvent +from agno.workflow.step import Step +from agno.workflow.steps import Steps +from agno.workflow.types import StepOutput, StepInput +from agno.workflow.parallel import Parallel +from agno.workflow.workflow import Workflow + +def my_sum(a: int, b: int) -> int: + return a + b + +def my_mul(a: int, b: int) -> int: + return a * b + +def build_agent(instructions: str) -> Agent: + return Agent( + instructions=instructions, + model=Ollama(id='qwen3:1.7b'), + tools=[my_sum] + ) + +def remove_think(text: str) -> str: + thinking = text.rfind("") + if thinking != -1: + return text[thinking + len(""):].strip() + return text.strip() + +def combine_steps_output(inputs: StepInput) -> StepOutput: + parallel = inputs.get_step_content("parallel") + if not isinstance(parallel, dict): return StepOutput() + + lang = remove_think(parallel.get("Lang", "")) + answer = remove_think(parallel.get("Predict", "")) + content = f"Language: {lang}\nPhrase: {answer}" + return StepOutput(content=content) + +async def main(): + query = "Quanto fa 50 + 150 * 50?" + + s1 = Step(name="Translate", agent=build_agent(instructions="Transform in English the user query. DO NOT answer the question and output ONLY the translated question.")) + s2 = Step(name="Predict", agent=build_agent(instructions="You will be given a question in English. You can use the tools at your disposal. Answer the question and output ONLY the answer.")) + + step_a = Step(name="Lang", agent=build_agent(instructions="Detect the language from the question and output ONLY the language code. Es: 'en' for English, 'it' for Italian, 'ja' for Japanese.")) + step_b = Steps(name="Answer", steps=[s1, s2]) + step_c = Step(name="Combine", executor=combine_steps_output) + step_f = Step(name="Final", agent=build_agent(instructions="Translate the phrase in the language code provided. Respond only with the translated answer.")) + + wf = Workflow(name="Pipeline Workflow", steps=[ + Parallel(step_a, step_b, name="parallel"), # type: ignore + step_c, + step_f + ]) + + result = "" + async for event in await wf.arun(query, stream=True, stream_intermediate_steps=True): + content = getattr(event, 'content', '') + step_name = getattr(event, 'step_name', '') + + if event.event in [WorkflowRunEvent.step_completed]: + print(f"{str(event.event)} --- {step_name} --- {remove_think(content).replace('\n', '\\n')[:80]}") + if event.event in [WorkflowRunEvent.workflow_completed]: + result = remove_think(content) + print(f"\nFinal result: {result}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/app/__main__.py b/src/app/__main__.py index 0c88872..04bc1d5 100644 --- a/src/app/__main__.py +++ b/src/app/__main__.py @@ -3,22 +3,21 @@ import logging from dotenv import load_dotenv from app.configs import AppConfig from app.interface import * -from app.agents import Pipeline if __name__ == "__main__": + # ===================== load_dotenv() - configs = AppConfig.load() - pipeline = Pipeline(configs) + # ===================== - chat = ChatManager(pipeline) + chat = ChatManager() gradio = chat.gradio_build_interface() _app, local_url, share_url = gradio.launch(server_name="0.0.0.0", server_port=configs.port, quiet=True, prevent_thread_lock=True, share=configs.gradio_share) logging.info(f"UPO AppAI Chat is running on {share_url or local_url}") try: - telegram = TelegramApp(pipeline) + telegram = TelegramApp() telegram.add_miniapp_url(share_url) telegram.run() except AssertionError as e: diff --git a/src/app/agents/__init__.py b/src/app/agents/__init__.py index 7d4287b..2e78f1b 100644 --- a/src/app/agents/__init__.py +++ b/src/app/agents/__init__.py @@ -1,5 +1,4 @@ from app.agents.predictor import PredictorInput, PredictorOutput -from app.agents.team import create_team_with -from app.agents.pipeline import Pipeline +from app.agents.pipeline import Pipeline, PipelineInputs, PipelineEvent -__all__ = ["PredictorInput", "PredictorOutput", "create_team_with", "Pipeline"] +__all__ = ["PredictorInput", "PredictorOutput", "Pipeline", "PipelineInputs", "PipelineEvent"] diff --git a/src/app/agents/pipeline.py b/src/app/agents/pipeline.py index 3338cb8..cf8de3e 100644 --- a/src/app/agents/pipeline.py +++ b/src/app/agents/pipeline.py @@ -1,32 +1,59 @@ +import asyncio +from enum import Enum import logging -from app.agents.team import create_team_with +import random +from typing import Any, Callable +from agno.agent import RunEvent +from agno.team import Team, TeamRunEvent +from agno.tools.reasoning import ReasoningTools +from agno.run.workflow import WorkflowRunEvent +from agno.workflow.step import Step +from agno.workflow.workflow import Workflow + +from app.api.tools import * from app.agents.prompts import * from app.configs import AppConfig logging = logging.getLogger("pipeline") -class Pipeline: +class PipelineEvent(str, Enum): + PLANNER = "Planner" + INFO_RECOVERY = "Info Recovery" + REPORT_GENERATION = "Report Generation" + REPORT_TRANSLATION = "Report Translation" + TOOL_USED = RunEvent.tool_call_completed + + def check_event(self, event: str, step_name: str) -> bool: + return event == self.value or (WorkflowRunEvent.step_completed and step_name == self.value) + + +class PipelineInputs: """ - Coordina gli agenti di servizio (Market, News, Social) e il Predictor finale. - Il Team è orchestrato da qwen3:latest (Ollama), mentre il Predictor è dinamico - e scelto dall'utente tramite i dropdown dell'interfaccia grafica. + Classe necessaria per passare gli input alla Pipeline. + Serve per raggruppare i parametri e semplificare l'inizializzazione. """ - def __init__(self, configs: AppConfig): - self.configs = configs + def __init__(self, configs: AppConfig | None = None) -> None: + """ + Inputs per la Pipeline di agenti. + Setta i valori di default se non specificati. + """ + self.configs = configs if configs else AppConfig() - # Stato iniziale - self.leader_model = self.configs.get_model_by_name(self.configs.agents.team_leader_model) - self.team_model = self.configs.get_model_by_name(self.configs.agents.team_model) - self.strategy = self.configs.get_strategy_by_name(self.configs.agents.strategy) + agents = self.configs.agents + self.team_model = self.configs.get_model_by_name(agents.team_model) + self.team_leader_model = self.configs.get_model_by_name(agents.team_leader_model) + self.predictor_model = self.configs.get_model_by_name(agents.predictor_model) + self.strategy = self.configs.get_strategy_by_name(agents.strategy) + self.user_query = "" # ====================== # Dropdown handlers # ====================== - def choose_leader(self, index: int): + def choose_team_leader(self, index: int): """ - Sceglie il modello LLM da usare per il Team. + Sceglie il modello LLM da usare per il Team Leader. """ self.leader_model = self.configs.models.all_models[index] @@ -38,47 +65,139 @@ class Pipeline: def choose_strategy(self, index: int): """ - Sceglie la strategia da usare per il Predictor. + Sceglie la strategia da usare per il Team. """ self.strategy = self.configs.strategies[index] # ====================== # Helpers # ====================== - def list_providers(self) -> list[str]: + def list_models_names(self) -> list[str]: """ Restituisce la lista dei nomi dei modelli disponibili. """ return [model.label for model in self.configs.models.all_models] - def list_styles(self) -> list[str]: + def list_strategies_names(self) -> list[str]: """ - Restituisce la lista degli stili di previsione disponibili. + Restituisce la lista delle strategie disponibili. """ return [strat.label for strat in self.configs.strategies] + +class Pipeline: + """ + Coordina gli agenti di servizio (Market, News, Social) e il Predictor finale. + Il Team è orchestrato da qwen3:latest (Ollama), mentre il Predictor è dinamico + e scelto dall'utente tramite i dropdown dell'interfaccia grafica. + """ + + def __init__(self, inputs: PipelineInputs): + self.inputs = inputs + # ====================== # Core interaction # ====================== - def interact(self, query: str) -> str: + def interact(self, listeners: dict[RunEvent | TeamRunEvent, Callable[[PipelineEvent], None]] = {}) -> str: """ - 1. Raccoglie output dai membri del Team - 2. Aggrega output strutturati - 3. Invoca Predictor - 4. Restituisce la strategia finale + Esegue la pipeline di agenti per rispondere alla query dell'utente. + Args: + listeners: dizionario di callback per eventi specifici (opzionale) + Returns: + La risposta generata dalla pipeline. """ - # Step 1: Creazione Team - team = create_team_with(self.configs, self.team_model, self.leader_model) + return asyncio.run(self.interact_async(listeners)) - # Step 2: raccolta output dai membri del Team - logging.info(f"Pipeline received query: {query}") - # TODO migliorare prompt (?) - query = f"The user query is: {query}\n\n They requested a {self.strategy.label} investment strategy." - team_outputs = team.run(query) # type: ignore + async def interact_async(self, listeners: dict[RunEvent | TeamRunEvent, Callable[[PipelineEvent], None]] = {}) -> str: + """ + Versione asincrona che esegue la pipeline di agenti per rispondere alla query dell'utente. + Args: + listeners: dizionario di callback per eventi specifici (opzionale) + Returns: + La risposta generata dalla pipeline. + """ + run_id = random.randint(1000, 9999) # Per tracciare i log + logging.info(f"[{run_id}] Pipeline query: {self.inputs.user_query}") - # Step 3: recupero ouput - if not isinstance(team_outputs.content, str): - logging.error(f"Team output is not a string: {team_outputs.content}") - raise ValueError("Team output is not a string") - logging.info(f"Team finished") - return team_outputs.content + # Step 1: Crea gli agenti e il team + market_tool, news_tool, social_tool = self.get_tools() + market_agent = self.inputs.team_model.get_agent(instructions=MARKET_INSTRUCTIONS, name="MarketAgent", tools=[market_tool]) + news_agent = self.inputs.team_model.get_agent(instructions=NEWS_INSTRUCTIONS, name="NewsAgent", tools=[news_tool]) + social_agent = self.inputs.team_model.get_agent(instructions=SOCIAL_INSTRUCTIONS, name="SocialAgent", tools=[social_tool]) + + team = Team( + model=self.inputs.team_leader_model.get_model(COORDINATOR_INSTRUCTIONS), + name="CryptoAnalysisTeam", + tools=[ReasoningTools()], + members=[market_agent, news_agent, social_agent], + ) + + # Step 3: Crea il workflow + #query_planner = Step(name=PipelineEvent.PLANNER, agent=Agent()) + info_recovery = Step(name=PipelineEvent.INFO_RECOVERY, team=team) + #report_generation = Step(name=PipelineEvent.REPORT_GENERATION, agent=Agent()) + #report_translate = Step(name=AppEvent.REPORT_TRANSLATION, agent=Agent()) + + workflow = Workflow( + name="App Workflow", + steps=[ + #query_planner, + info_recovery, + #report_generation, + #report_translate + ] + ) + + # Step 4: Fai partire il workflow e prendi l'output + query = f"The user query is: {self.inputs.user_query}\n\n They requested a {self.inputs.strategy.label} investment strategy." + result = await self.run(workflow, query, events={}) + logging.info(f"[{run_id}] Run finished") + return result + + # ====================== + # Helpers + # ===================== + def get_tools(self) -> tuple[MarketAPIsTool, NewsAPIsTool, SocialAPIsTool]: + """ + Restituisce la lista di tools disponibili per gli agenti. + """ + api = self.inputs.configs.api + + market_tool = MarketAPIsTool(currency=api.currency) + market_tool.handler.set_retries(api.retry_attempts, api.retry_delay_seconds) + news_tool = NewsAPIsTool() + news_tool.handler.set_retries(api.retry_attempts, api.retry_delay_seconds) + social_tool = SocialAPIsTool() + social_tool.handler.set_retries(api.retry_attempts, api.retry_delay_seconds) + + return (market_tool, news_tool, social_tool) + + @classmethod + async def run(cls, workflow: Workflow, query: str, events: dict[PipelineEvent, Callable[[Any], None]]) -> str: + """ + Esegue il workflow e gestisce gli eventi tramite le callback fornite. + Args: + workflow: istanza di Workflow da eseguire + query: query dell'utente da passare al workflow + events: dizionario di callback per eventi specifici (opzionale) + Returns: + La risposta generata dal workflow. + """ + iterator = await workflow.arun(query, stream=True, stream_intermediate_steps=True) + + content = None + async for event in iterator: + step_name = getattr(event, 'step_name', '') + + for app_event, listener in events.items(): + if app_event.check_event(event.event, step_name): + listener(event) + + if event.event == WorkflowRunEvent.workflow_completed: + content = getattr(event, 'content', '') + if isinstance(content, str): + think_str = "" + think = content.rfind(think_str) + content = content[(think + len(think_str)):] if think != -1 else content + + return content if content else "No output from workflow, something went wrong." diff --git a/src/app/agents/team.py b/src/app/agents/team.py deleted file mode 100644 index 4fcad4e..0000000 --- a/src/app/agents/team.py +++ /dev/null @@ -1,25 +0,0 @@ -from agno.team import Team -from app.api.tools import * -from app.agents.prompts import * -from app.configs import AppConfig, AppModel - - -def create_team_with(configs: AppConfig, model: AppModel, coordinator: AppModel | None = None) -> Team: - - market_tool = MarketAPIsTool(currency=configs.api.currency) - market_tool.handler.set_retries(configs.api.retry_attempts, configs.api.retry_delay_seconds) - news_tool = NewsAPIsTool() - news_tool.handler.set_retries(configs.api.retry_attempts, configs.api.retry_delay_seconds) - social_tool = SocialAPIsTool() - social_tool.handler.set_retries(configs.api.retry_attempts, configs.api.retry_delay_seconds) - - market_agent = model.get_agent(instructions=MARKET_INSTRUCTIONS, name="MarketAgent", tools=[market_tool]) - news_agent = model.get_agent(instructions=NEWS_INSTRUCTIONS, name="NewsAgent", tools=[news_tool]) - social_agent = model.get_agent(instructions=SOCIAL_INSTRUCTIONS, name="SocialAgent", tools=[social_tool]) - - coordinator = coordinator or model - return Team( - model=coordinator.get_model(COORDINATOR_INSTRUCTIONS), - name="CryptoAnalysisTeam", - members=[market_agent, news_agent, social_agent], - ) diff --git a/src/app/configs.py b/src/app/configs.py index 29c2178..179ffdd 100644 --- a/src/app/configs.py +++ b/src/app/configs.py @@ -3,7 +3,6 @@ import threading import ollama import yaml import logging.config -import agno.utils.log # type: ignore from typing import Any, ClassVar from pydantic import BaseModel from agno.agent import Agent @@ -104,8 +103,6 @@ class AppConfig(BaseModel): data = yaml.safe_load(f) configs = cls(**data) - configs.set_logging_level() - configs.validate_models() log.info(f"Loaded configuration from {file_path}") return configs @@ -115,6 +112,15 @@ class AppConfig(BaseModel): cls.instance = super(AppConfig, cls).__new__(cls) return cls.instance + def __init__(self, *args: Any, **kwargs: Any) -> None: + if hasattr(self, '_initialized'): + return + + super().__init__(*args, **kwargs) + self.set_logging_level() + self.validate_models() + self._initialized = True + def get_model_by_name(self, name: str) -> AppModel: """ Retrieve a model configuration by its name. @@ -145,17 +151,6 @@ class AppConfig(BaseModel): return strat raise ValueError(f"Strategy with name '{name}' not found.") - def get_defaults(self) -> tuple[AppModel, AppModel, Strategy]: - """ - Retrieve the default team model, leader model, and strategy. - Returns: - A tuple containing the default team model (AppModel), leader model (AppModel), and strategy (Strategy). - """ - team_model = self.get_model_by_name(self.agents.team_model) - leader_model = self.get_model_by_name(self.agents.team_leader_model) - strategy = self.get_strategy_by_name(self.agents.strategy) - return team_model, leader_model, strategy - def set_logging_level(self) -> None: """ Set the logging level based on the configuration. diff --git a/src/app/interface/chat.py b/src/app/interface/chat.py index aaba2af..6881c32 100644 --- a/src/app/interface/chat.py +++ b/src/app/interface/chat.py @@ -1,7 +1,7 @@ import os import json import gradio as gr -from app.agents.pipeline import Pipeline +from app.agents.pipeline import Pipeline, PipelineInputs class ChatManager: @@ -12,9 +12,9 @@ class ChatManager: - salva e ricarica le chat """ - def __init__(self, pipeline: Pipeline): + def __init__(self): self.history: list[dict[str, str]] = [] # [{"role": "user"/"assistant", "content": "..."}] - self.pipeline = pipeline + self.inputs = PipelineInputs() def send_message(self, message: str) -> None: """ @@ -67,7 +67,11 @@ class ChatManager: ######################################## def gradio_respond(self, message: str, history: list[dict[str, str]]) -> tuple[list[dict[str, str]], list[dict[str, str]], str]: self.send_message(message) - response = self.pipeline.interact(message) + + self.inputs.user_query = message + pipeline = Pipeline(self.inputs) + response = pipeline.interact() + self.receive_message(response) history.append({"role": "user", "content": message}) history.append({"role": "assistant", "content": response}) @@ -95,18 +99,18 @@ class ChatManager: # Dropdown provider e stile with gr.Row(): provider = gr.Dropdown( - choices=self.pipeline.list_providers(), + choices=self.inputs.list_models_names(), type="index", label="Modello da usare" ) - provider.change(fn=self.pipeline.choose_leader, inputs=provider, outputs=None) + provider.change(fn=self.inputs.choose_team_leader, inputs=provider, outputs=None) style = gr.Dropdown( - choices=self.pipeline.list_styles(), + choices=self.inputs.list_strategies_names(), type="index", label="Stile di investimento" ) - style.change(fn=self.pipeline.choose_strategy, inputs=style, outputs=None) + style.change(fn=self.inputs.choose_strategy, inputs=style, outputs=None) chatbot = gr.Chatbot(label="Conversazione", height=500, type="messages") msg = gr.Textbox(label="Scrivi la tua richiesta", placeholder="Es: Quali sono le crypto interessanti oggi?") diff --git a/src/app/interface/telegram_app.py b/src/app/interface/telegram_app.py index 3bef9d9..71ff4c8 100644 --- a/src/app/interface/telegram_app.py +++ b/src/app/interface/telegram_app.py @@ -9,8 +9,7 @@ from markdown_pdf import MarkdownPdf, Section from telegram import CallbackQuery, InlineKeyboardButton, InlineKeyboardMarkup, Message, Update, User from telegram.constants import ChatAction from telegram.ext import Application, CallbackQueryHandler, CommandHandler, ContextTypes, ConversationHandler, MessageHandler, filters -from app.agents.pipeline import Pipeline -from app.configs import AppConfig +from app.agents.pipeline import Pipeline, PipelineInputs # per per_message di ConversationHandler che rompe sempre qualunque input tu metta warnings.filterwarnings("ignore") @@ -40,22 +39,12 @@ class ConfigsChat(Enum): MODEL_OUTPUT = "Output Model" STRATEGY = "Strategy" -class ConfigsRun: - def __init__(self, configs: AppConfig): - team, leader, strategy = configs.get_defaults() - self.team_model = team - self.leader_model = leader - self.strategy = strategy - self.user_query = "" - - class TelegramApp: - def __init__(self, pipeline: Pipeline): + def __init__(self): token = os.getenv("TELEGRAM_BOT_TOKEN") assert token, "TELEGRAM_BOT_TOKEN environment variable not set" - self.user_requests: dict[User, ConfigsRun] = {} - self.pipeline = pipeline + self.user_requests: dict[User, PipelineInputs] = {} self.token = token self.create_bot() @@ -104,10 +93,10 @@ class TelegramApp: # Funzioni di utilità ######################################## async def start_message(self, user: User, query: CallbackQuery | Message) -> None: - confs = self.user_requests.setdefault(user, ConfigsRun(self.pipeline.configs)) + confs = self.user_requests.setdefault(user, PipelineInputs()) str_model_team = f"{ConfigsChat.MODEL_TEAM.value}: {confs.team_model.label}" - str_model_output = f"{ConfigsChat.MODEL_OUTPUT.value}: {confs.leader_model.label}" + str_model_output = f"{ConfigsChat.MODEL_OUTPUT.value}: {confs.team_leader_model.label}" str_strategy = f"{ConfigsChat.STRATEGY.value}: {confs.strategy.label}" msg, keyboard = ( @@ -135,8 +124,8 @@ class TelegramApp: assert update.message and update.message.from_user, "Update message or user is None" return update.message, update.message.from_user - def callback_data(self, strings: list[str]) -> str: - return QUERY_SEP.join(strings) + def build_callback_data(self, callback: str, config: ConfigsChat, labels: list[str]) -> list[tuple[str, str]]: + return [(label, QUERY_SEP.join((callback, config.value, str(i)))) for i, label in enumerate(labels)] async def __error_handler(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None: try: @@ -168,18 +157,20 @@ class TelegramApp: return await self._model_select(update, ConfigsChat.MODEL_OUTPUT) async def _model_select(self, update: Update, state: ConfigsChat, msg: str | None = None) -> int: - query, _ = await self.handle_callbackquery(update) + query, user = await self.handle_callbackquery(update) - models = [(m.label, self.callback_data([f"__select_config", str(state), m.name])) for m in self.pipeline.configs.models.all_models] + req = self.user_requests[user] + models = self.build_callback_data("__select_config", state, req.list_models_names()) inline_btns = [[InlineKeyboardButton(name, callback_data=callback_data)] for name, callback_data in models] await query.edit_message_text(msg or state.value, reply_markup=InlineKeyboardMarkup(inline_btns)) return SELECT_CONFIG async def __strategy(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: - query, _ = await self.handle_callbackquery(update) + query, user = await self.handle_callbackquery(update) - strategies = [(s.label, self.callback_data([f"__select_config", str(ConfigsChat.STRATEGY), s.name])) for s in self.pipeline.configs.strategies] + req = self.user_requests[user] + strategies = self.build_callback_data("__select_config", ConfigsChat.STRATEGY, req.list_strategies_names()) inline_btns = [[InlineKeyboardButton(name, callback_data=callback_data)] for name, callback_data in strategies] await query.edit_message_text("Select a strategy", reply_markup=InlineKeyboardMarkup(inline_btns)) @@ -190,13 +181,13 @@ class TelegramApp: logging.debug(f"@{user.username} --> {query.data}") req = self.user_requests[user] - _, state, model_name = str(query.data).split(QUERY_SEP) + _, state, index = str(query.data).split(QUERY_SEP) if state == str(ConfigsChat.MODEL_TEAM): - req.team_model = self.pipeline.configs.get_model_by_name(model_name) + req.choose_team(int(index)) if state == str(ConfigsChat.MODEL_OUTPUT): - req.leader_model = self.pipeline.configs.get_model_by_name(model_name) + req.choose_team_leader(int(index)) if state == str(ConfigsChat.STRATEGY): - req.strategy = self.pipeline.configs.get_strategy_by_name(model_name) + req.choose_strategy(int(index)) await self.start_message(user, query) return CONFIGS @@ -207,7 +198,7 @@ class TelegramApp: confs = self.user_requests[user] confs.user_query = message.text or "" - logging.info(f"@{user.username} started the team with [{confs.team_model.label}, {confs.leader_model.label}, {confs.strategy.label}]") + logging.info(f"@{user.username} started the team with [{confs.team_model.label}, {confs.team_leader_model.label}, {confs.strategy.label}]") await self.__run_team(update, confs) logging.info(f"@{user.username} team finished.") @@ -221,7 +212,7 @@ class TelegramApp: await query.edit_message_text("Conversation canceled. Use /start to begin again.") return ConversationHandler.END - async def __run_team(self, update: Update, confs: ConfigsRun) -> None: + async def __run_team(self, update: Update, inputs: PipelineInputs) -> None: if not update.message: return bot = update.get_bot() @@ -230,10 +221,10 @@ class TelegramApp: configs_str = [ 'Running with configurations: ', - f'Team: {confs.team_model.label}', - f'Output: {confs.leader_model.label}', - f'Strategy: {confs.strategy.label}', - f'Query: "{confs.user_query}"' + f'Team: {inputs.team_model.label}', + f'Output: {inputs.team_leader_model.label}', + f'Strategy: {inputs.strategy.label}', + f'Query: "{inputs.user_query}"' ] full_message = f"""```\n{'\n'.join(configs_str)}\n```\n\n""" first_message = full_message + "Generating report, please wait" @@ -243,13 +234,10 @@ class TelegramApp: # Remove user query and bot message await bot.delete_message(chat_id=chat_id, message_id=update.message.id) - self.pipeline.leader_model = confs.leader_model - self.pipeline.team_model = confs.team_model - self.pipeline.strategy = confs.strategy - # TODO migliorare messaggi di attesa await bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING) - report_content = self.pipeline.interact(confs.user_query) + pipeline = Pipeline(inputs) + report_content = await pipeline.interact_async() await msg.delete() # attach report file to the message