Refactor team management (#26)
* Refactor pipeline integration * remove direct pipeline dependency from ChatManager and TelegramApp * introduce PipelineInputs for better configuration management * listener personalizzati per eventi nella funzione di interazione della pipeline * added demos for agno * USD in configs * Dockerfile better cache
This commit was merged in pull request #26.
This commit is contained in:
committed by
GitHub
parent
d85d6ed1eb
commit
38daafce9a
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import gradio as gr
|
||||
from app.agents.pipeline import Pipeline
|
||||
from app.agents.pipeline import Pipeline, PipelineInputs
|
||||
|
||||
|
||||
class ChatManager:
|
||||
@@ -12,9 +12,9 @@ class ChatManager:
|
||||
- salva e ricarica le chat
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline: Pipeline):
|
||||
def __init__(self):
|
||||
self.history: list[dict[str, str]] = [] # [{"role": "user"/"assistant", "content": "..."}]
|
||||
self.pipeline = pipeline
|
||||
self.inputs = PipelineInputs()
|
||||
|
||||
def send_message(self, message: str) -> None:
|
||||
"""
|
||||
@@ -67,7 +67,11 @@ class ChatManager:
|
||||
########################################
|
||||
def gradio_respond(self, message: str, history: list[dict[str, str]]) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
|
||||
self.send_message(message)
|
||||
response = self.pipeline.interact(message)
|
||||
|
||||
self.inputs.user_query = message
|
||||
pipeline = Pipeline(self.inputs)
|
||||
response = pipeline.interact()
|
||||
|
||||
self.receive_message(response)
|
||||
history.append({"role": "user", "content": message})
|
||||
history.append({"role": "assistant", "content": response})
|
||||
@@ -95,18 +99,18 @@ class ChatManager:
|
||||
# Dropdown provider e stile
|
||||
with gr.Row():
|
||||
provider = gr.Dropdown(
|
||||
choices=self.pipeline.list_providers(),
|
||||
choices=self.inputs.list_models_names(),
|
||||
type="index",
|
||||
label="Modello da usare"
|
||||
)
|
||||
provider.change(fn=self.pipeline.choose_leader, inputs=provider, outputs=None)
|
||||
provider.change(fn=self.inputs.choose_team_leader, inputs=provider, outputs=None)
|
||||
|
||||
style = gr.Dropdown(
|
||||
choices=self.pipeline.list_styles(),
|
||||
choices=self.inputs.list_strategies_names(),
|
||||
type="index",
|
||||
label="Stile di investimento"
|
||||
)
|
||||
style.change(fn=self.pipeline.choose_strategy, inputs=style, outputs=None)
|
||||
style.change(fn=self.inputs.choose_strategy, inputs=style, outputs=None)
|
||||
|
||||
chatbot = gr.Chatbot(label="Conversazione", height=500, type="messages")
|
||||
msg = gr.Textbox(label="Scrivi la tua richiesta", placeholder="Es: Quali sono le crypto interessanti oggi?")
|
||||
|
||||
@@ -9,8 +9,7 @@ from markdown_pdf import MarkdownPdf, Section
|
||||
from telegram import CallbackQuery, InlineKeyboardButton, InlineKeyboardMarkup, Message, Update, User
|
||||
from telegram.constants import ChatAction
|
||||
from telegram.ext import Application, CallbackQueryHandler, CommandHandler, ContextTypes, ConversationHandler, MessageHandler, filters
|
||||
from app.agents.pipeline import Pipeline
|
||||
from app.configs import AppConfig
|
||||
from app.agents.pipeline import Pipeline, PipelineInputs
|
||||
|
||||
# per per_message di ConversationHandler che rompe sempre qualunque input tu metta
|
||||
warnings.filterwarnings("ignore")
|
||||
@@ -40,22 +39,12 @@ class ConfigsChat(Enum):
|
||||
MODEL_OUTPUT = "Output Model"
|
||||
STRATEGY = "Strategy"
|
||||
|
||||
class ConfigsRun:
|
||||
def __init__(self, configs: AppConfig):
|
||||
team, leader, strategy = configs.get_defaults()
|
||||
self.team_model = team
|
||||
self.leader_model = leader
|
||||
self.strategy = strategy
|
||||
self.user_query = ""
|
||||
|
||||
|
||||
class TelegramApp:
|
||||
def __init__(self, pipeline: Pipeline):
|
||||
def __init__(self):
|
||||
token = os.getenv("TELEGRAM_BOT_TOKEN")
|
||||
assert token, "TELEGRAM_BOT_TOKEN environment variable not set"
|
||||
|
||||
self.user_requests: dict[User, ConfigsRun] = {}
|
||||
self.pipeline = pipeline
|
||||
self.user_requests: dict[User, PipelineInputs] = {}
|
||||
self.token = token
|
||||
self.create_bot()
|
||||
|
||||
@@ -104,10 +93,10 @@ class TelegramApp:
|
||||
# Funzioni di utilità
|
||||
########################################
|
||||
async def start_message(self, user: User, query: CallbackQuery | Message) -> None:
|
||||
confs = self.user_requests.setdefault(user, ConfigsRun(self.pipeline.configs))
|
||||
confs = self.user_requests.setdefault(user, PipelineInputs())
|
||||
|
||||
str_model_team = f"{ConfigsChat.MODEL_TEAM.value}: {confs.team_model.label}"
|
||||
str_model_output = f"{ConfigsChat.MODEL_OUTPUT.value}: {confs.leader_model.label}"
|
||||
str_model_output = f"{ConfigsChat.MODEL_OUTPUT.value}: {confs.team_leader_model.label}"
|
||||
str_strategy = f"{ConfigsChat.STRATEGY.value}: {confs.strategy.label}"
|
||||
|
||||
msg, keyboard = (
|
||||
@@ -135,8 +124,8 @@ class TelegramApp:
|
||||
assert update.message and update.message.from_user, "Update message or user is None"
|
||||
return update.message, update.message.from_user
|
||||
|
||||
def callback_data(self, strings: list[str]) -> str:
|
||||
return QUERY_SEP.join(strings)
|
||||
def build_callback_data(self, callback: str, config: ConfigsChat, labels: list[str]) -> list[tuple[str, str]]:
|
||||
return [(label, QUERY_SEP.join((callback, config.value, str(i)))) for i, label in enumerate(labels)]
|
||||
|
||||
async def __error_handler(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
try:
|
||||
@@ -168,18 +157,20 @@ class TelegramApp:
|
||||
return await self._model_select(update, ConfigsChat.MODEL_OUTPUT)
|
||||
|
||||
async def _model_select(self, update: Update, state: ConfigsChat, msg: str | None = None) -> int:
|
||||
query, _ = await self.handle_callbackquery(update)
|
||||
query, user = await self.handle_callbackquery(update)
|
||||
|
||||
models = [(m.label, self.callback_data([f"__select_config", str(state), m.name])) for m in self.pipeline.configs.models.all_models]
|
||||
req = self.user_requests[user]
|
||||
models = self.build_callback_data("__select_config", state, req.list_models_names())
|
||||
inline_btns = [[InlineKeyboardButton(name, callback_data=callback_data)] for name, callback_data in models]
|
||||
|
||||
await query.edit_message_text(msg or state.value, reply_markup=InlineKeyboardMarkup(inline_btns))
|
||||
return SELECT_CONFIG
|
||||
|
||||
async def __strategy(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
|
||||
query, _ = await self.handle_callbackquery(update)
|
||||
query, user = await self.handle_callbackquery(update)
|
||||
|
||||
strategies = [(s.label, self.callback_data([f"__select_config", str(ConfigsChat.STRATEGY), s.name])) for s in self.pipeline.configs.strategies]
|
||||
req = self.user_requests[user]
|
||||
strategies = self.build_callback_data("__select_config", ConfigsChat.STRATEGY, req.list_strategies_names())
|
||||
inline_btns = [[InlineKeyboardButton(name, callback_data=callback_data)] for name, callback_data in strategies]
|
||||
|
||||
await query.edit_message_text("Select a strategy", reply_markup=InlineKeyboardMarkup(inline_btns))
|
||||
@@ -190,13 +181,13 @@ class TelegramApp:
|
||||
logging.debug(f"@{user.username} --> {query.data}")
|
||||
|
||||
req = self.user_requests[user]
|
||||
_, state, model_name = str(query.data).split(QUERY_SEP)
|
||||
_, state, index = str(query.data).split(QUERY_SEP)
|
||||
if state == str(ConfigsChat.MODEL_TEAM):
|
||||
req.team_model = self.pipeline.configs.get_model_by_name(model_name)
|
||||
req.choose_team(int(index))
|
||||
if state == str(ConfigsChat.MODEL_OUTPUT):
|
||||
req.leader_model = self.pipeline.configs.get_model_by_name(model_name)
|
||||
req.choose_team_leader(int(index))
|
||||
if state == str(ConfigsChat.STRATEGY):
|
||||
req.strategy = self.pipeline.configs.get_strategy_by_name(model_name)
|
||||
req.choose_strategy(int(index))
|
||||
|
||||
await self.start_message(user, query)
|
||||
return CONFIGS
|
||||
@@ -207,7 +198,7 @@ class TelegramApp:
|
||||
confs = self.user_requests[user]
|
||||
confs.user_query = message.text or ""
|
||||
|
||||
logging.info(f"@{user.username} started the team with [{confs.team_model.label}, {confs.leader_model.label}, {confs.strategy.label}]")
|
||||
logging.info(f"@{user.username} started the team with [{confs.team_model.label}, {confs.team_leader_model.label}, {confs.strategy.label}]")
|
||||
await self.__run_team(update, confs)
|
||||
|
||||
logging.info(f"@{user.username} team finished.")
|
||||
@@ -221,7 +212,7 @@ class TelegramApp:
|
||||
await query.edit_message_text("Conversation canceled. Use /start to begin again.")
|
||||
return ConversationHandler.END
|
||||
|
||||
async def __run_team(self, update: Update, confs: ConfigsRun) -> None:
|
||||
async def __run_team(self, update: Update, inputs: PipelineInputs) -> None:
|
||||
if not update.message: return
|
||||
|
||||
bot = update.get_bot()
|
||||
@@ -230,10 +221,10 @@ class TelegramApp:
|
||||
|
||||
configs_str = [
|
||||
'Running with configurations: ',
|
||||
f'Team: {confs.team_model.label}',
|
||||
f'Output: {confs.leader_model.label}',
|
||||
f'Strategy: {confs.strategy.label}',
|
||||
f'Query: "{confs.user_query}"'
|
||||
f'Team: {inputs.team_model.label}',
|
||||
f'Output: {inputs.team_leader_model.label}',
|
||||
f'Strategy: {inputs.strategy.label}',
|
||||
f'Query: "{inputs.user_query}"'
|
||||
]
|
||||
full_message = f"""```\n{'\n'.join(configs_str)}\n```\n\n"""
|
||||
first_message = full_message + "Generating report, please wait"
|
||||
@@ -243,13 +234,10 @@ class TelegramApp:
|
||||
# Remove user query and bot message
|
||||
await bot.delete_message(chat_id=chat_id, message_id=update.message.id)
|
||||
|
||||
self.pipeline.leader_model = confs.leader_model
|
||||
self.pipeline.team_model = confs.team_model
|
||||
self.pipeline.strategy = confs.strategy
|
||||
|
||||
# TODO migliorare messaggi di attesa
|
||||
await bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING)
|
||||
report_content = self.pipeline.interact(confs.user_query)
|
||||
pipeline = Pipeline(inputs)
|
||||
report_content = await pipeline.interact_async()
|
||||
await msg.delete()
|
||||
|
||||
# attach report file to the message
|
||||
|
||||
Reference in New Issue
Block a user