Riorganizza e rinomina funzioni di estrazione in moduli di mercato e notizie; migliora la gestione delle importazioni

This commit is contained in:
2025-10-03 19:40:14 +02:00
parent b85d74a662
commit 8d1cae8706
19 changed files with 119 additions and 118 deletions

View File

@@ -1,17 +1,12 @@
import gradio as gr import gradio as gr
from agno.utils.log import log_info from agno.utils.log import log_info
from dotenv import load_dotenv from dotenv import load_dotenv
from app.chat_manager import ChatManager from app.chat_manager import ChatManager
########################################
# MAIN APP & GRADIO CHAT INTERFACE
########################################
if __name__ == "__main__":
# Carica variabili dambiente (.env)
load_dotenv()
# Inizializza ChatManager if __name__ == "__main__":
# Inizializzazioni
load_dotenv()
chat = ChatManager() chat = ChatManager()
######################################## ########################################
@@ -68,16 +63,13 @@ if __name__ == "__main__":
save_btn = gr.Button("💾 Salva Chat") save_btn = gr.Button("💾 Salva Chat")
load_btn = gr.Button("📂 Carica Chat") load_btn = gr.Button("📂 Carica Chat")
# Invio messaggio # Eventi e interazioni
msg.submit(respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot, msg]) msg.submit(respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot, msg])
# Reset
clear_btn.click(reset_chat, inputs=None, outputs=[chatbot, chatbot]) clear_btn.click(reset_chat, inputs=None, outputs=[chatbot, chatbot])
# Salvataggio
save_btn.click(save_current_chat, inputs=None, outputs=None) save_btn.click(save_current_chat, inputs=None, outputs=None)
# Caricamento
load_btn.click(load_previous_chat, inputs=None, outputs=[chatbot, chatbot]) load_btn.click(load_previous_chat, inputs=None, outputs=[chatbot, chatbot])
server, port = ("0.0.0.0", 8000) server, port = ("0.0.0.0", 8000) # 0.0.0.0 per accesso esterno (Docker)
server_log = "localhost" if server == "0.0.0.0" else server server_log = "localhost" if server == "0.0.0.0" else server
log_info(f"Starting UPO AppAI Chat on http://{server_log}:{port}") # noqa log_info(f"Starting UPO AppAI Chat on http://{server_log}:{port}") # noqa
demo.launch(server_name=server, server_port=port, quiet=True) demo.launch(server_name=server, server_port=port, quiet=True)

View File

@@ -1,11 +1,11 @@
import os
import json import json
from typing import List, Dict import os
from app.pipeline import Pipeline from app.pipeline import Pipeline
SAVE_DIR = os.path.join(os.path.dirname(__file__), "..", "saves") SAVE_DIR = os.path.join(os.path.dirname(__file__), "..", "saves")
os.makedirs(SAVE_DIR, exist_ok=True) os.makedirs(SAVE_DIR, exist_ok=True)
class ChatManager: class ChatManager:
""" """
Gestisce la conversazione con la Pipeline: Gestisce la conversazione con la Pipeline:
@@ -16,7 +16,7 @@ class ChatManager:
def __init__(self): def __init__(self):
self.pipeline = Pipeline() self.pipeline = Pipeline()
self.history: List[Dict[str, str]] = [] # [{"role": "user"/"assistant", "content": "..."}] self.history: list[dict[str, str]] = [] # [{"role": "user"/"assistant", "content": "..."}]
def send_message(self, message: str) -> str: def send_message(self, message: str) -> str:
""" """
@@ -58,7 +58,7 @@ class ChatManager:
""" """
self.history = [] self.history = []
def get_history(self) -> List[Dict[str, str]]: def get_history(self) -> list[dict[str, str]]:
""" """
Restituisce lo storico completo della chat. Restituisce lo storico completo della chat.
""" """
@@ -71,8 +71,8 @@ class ChatManager:
def choose_style(self, index: int): def choose_style(self, index: int):
self.pipeline.choose_style(index) self.pipeline.choose_style(index)
def list_providers(self) -> List[str]: def list_providers(self) -> list[str]:
return self.pipeline.list_providers() return self.pipeline.list_providers()
def list_styles(self) -> List[str]: def list_styles(self) -> list[str]:
return self.pipeline.list_styles() return self.pipeline.list_styles()

View File

@@ -1,11 +1,11 @@
from agno.tools import Toolkit from agno.tools import Toolkit
from app.markets.base import BaseWrapper, Price, ProductInfo
from app.markets.binance import BinanceWrapper
from app.markets.coinbase import CoinBaseWrapper
from app.markets.cryptocompare import CryptoCompareWrapper
from app.markets.yfinance import YFinanceWrapper
from app.utils.market_aggregation import aggregate_history_prices, aggregate_product_info
from app.utils.wrapper_handler import WrapperHandler from app.utils.wrapper_handler import WrapperHandler
from app.utils.market_aggregation import aggregate_product_info, aggregate_history_prices
from .base import BaseWrapper, ProductInfo, Price
from .coinbase import CoinBaseWrapper
from .binance import BinanceWrapper
from .cryptocompare import CryptoCompareWrapper
from .yfinance import YFinanceWrapper
__all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "MARKET_INSTRUCTIONS" ] __all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "MARKET_INSTRUCTIONS" ]

View File

@@ -1,41 +1,5 @@
from pydantic import BaseModel from pydantic import BaseModel
class BaseWrapper:
"""
Base class for market API wrappers.
All market API wrappers should inherit from this class and implement the methods.
"""
def get_product(self, asset_id: str) -> 'ProductInfo':
"""
Get product information for a specific asset ID.
Args:
asset_id (str): The asset ID to retrieve information for.
Returns:
ProductInfo: An object containing product information.
"""
raise NotImplementedError("This method should be overridden by subclasses")
def get_products(self, asset_ids: list[str]) -> list['ProductInfo']:
"""
Get product information for multiple asset IDs.
Args:
asset_ids (list[str]): The list of asset IDs to retrieve information for.
Returns:
list[ProductInfo]: A list of objects containing product information.
"""
raise NotImplementedError("This method should be overridden by subclasses")
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list['Price']:
"""
Get historical price data for a specific asset ID.
Args:
asset_id (str): The asset ID to retrieve price data for.
limit (int): The maximum number of price data points to return.
Returns:
list[Price]: A list of Price objects.
"""
raise NotImplementedError("This method should be overridden by subclasses")
class ProductInfo(BaseModel): class ProductInfo(BaseModel):
""" """
@@ -59,3 +23,40 @@ class Price(BaseModel):
close: float = 0.0 close: float = 0.0
volume: float = 0.0 volume: float = 0.0
timestamp_ms: int = 0 # Timestamp in milliseconds timestamp_ms: int = 0 # Timestamp in milliseconds
class BaseWrapper:
"""
Base class for market API wrappers.
All market API wrappers should inherit from this class and implement the methods.
"""
def get_product(self, asset_id: str) -> ProductInfo:
"""
Get product information for a specific asset ID.
Args:
asset_id (str): The asset ID to retrieve information for.
Returns:
ProductInfo: An object containing product information.
"""
raise NotImplementedError("This method should be overridden by subclasses")
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
"""
Get product information for multiple asset IDs.
Args:
asset_ids (list[str]): The list of asset IDs to retrieve information for.
Returns:
list[ProductInfo]: A list of objects containing product information.
"""
raise NotImplementedError("This method should be overridden by subclasses")
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
"""
Get historical price data for a specific asset ID.
Args:
asset_id (str): The asset ID to retrieve price data for.
limit (int): The maximum number of price data points to return.
Returns:
list[Price]: A list of Price objects.
"""
raise NotImplementedError("This method should be overridden by subclasses")

View File

@@ -1,9 +1,9 @@
import os import os
from datetime import datetime
from binance.client import Client from binance.client import Client
from .base import ProductInfo, BaseWrapper, Price from app.markets.base import ProductInfo, BaseWrapper, Price
def get_product(currency: str, ticker_data: dict[str, str]) -> ProductInfo:
def extract_product(currency: str, ticker_data: dict[str, str]) -> ProductInfo:
product = ProductInfo() product = ProductInfo()
product.id = ticker_data.get('symbol') product.id = ticker_data.get('symbol')
product.symbol = ticker_data.get('symbol', '').replace(currency, '') product.symbol = ticker_data.get('symbol', '').replace(currency, '')
@@ -12,7 +12,7 @@ def get_product(currency: str, ticker_data: dict[str, str]) -> ProductInfo:
product.quote_currency = currency product.quote_currency = currency
return product return product
def get_price(kline_data: list) -> Price: def extract_price(kline_data: list) -> Price:
price = Price() price = Price()
price.open = float(kline_data[1]) price.open = float(kline_data[1])
price.high = float(kline_data[2]) price.high = float(kline_data[2])
@@ -50,7 +50,7 @@ class BinanceWrapper(BaseWrapper):
ticker_24h = self.client.get_ticker(symbol=symbol) ticker_24h = self.client.get_ticker(symbol=symbol)
ticker['volume'] = ticker_24h.get('volume', 0) # Aggiunge volume 24h ai dati del ticker ticker['volume'] = ticker_24h.get('volume', 0) # Aggiunge volume 24h ai dati del ticker
return get_product(self.currency, ticker) return extract_product(self.currency, ticker)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
symbols = [self.__format_symbol(asset_id) for asset_id in asset_ids] symbols = [self.__format_symbol(asset_id) for asset_id in asset_ids]
@@ -61,7 +61,7 @@ class BinanceWrapper(BaseWrapper):
for t, t24 in zip(tickers, tickers_24h): for t, t24 in zip(tickers, tickers_24h):
t['volume'] = t24.get('volume', 0) t['volume'] = t24.get('volume', 0)
return [get_product(self.currency, ticker) for ticker in tickers] return [extract_product(self.currency, ticker) for ticker in tickers]
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
symbol = self.__format_symbol(asset_id) symbol = self.__format_symbol(asset_id)
@@ -72,5 +72,5 @@ class BinanceWrapper(BaseWrapper):
interval=Client.KLINE_INTERVAL_1HOUR, interval=Client.KLINE_INTERVAL_1HOUR,
limit=limit, limit=limit,
) )
return [get_price(kline) for kline in klines] return [extract_price(kline) for kline in klines]

View File

@@ -3,10 +3,10 @@ from enum import Enum
from datetime import datetime, timedelta from datetime import datetime, timedelta
from coinbase.rest import RESTClient from coinbase.rest import RESTClient
from coinbase.rest.types.product_types import Candle, GetProductResponse, Product from coinbase.rest.types.product_types import Candle, GetProductResponse, Product
from .base import ProductInfo, BaseWrapper, Price from app.markets.base import ProductInfo, BaseWrapper, Price
def get_product(product_data: GetProductResponse | Product) -> ProductInfo: def extract_product(product_data: GetProductResponse | Product) -> ProductInfo:
product = ProductInfo() product = ProductInfo()
product.id = product_data.product_id or "" product.id = product_data.product_id or ""
product.symbol = product_data.base_currency_id or "" product.symbol = product_data.base_currency_id or ""
@@ -14,7 +14,7 @@ def get_product(product_data: GetProductResponse | Product) -> ProductInfo:
product.volume_24h = float(product_data.volume_24h) if product_data.volume_24h else 0.0 product.volume_24h = float(product_data.volume_24h) if product_data.volume_24h else 0.0
return product return product
def get_price(candle_data: Candle) -> Price: def extract_price(candle_data: Candle) -> Price:
price = Price() price = Price()
price.high = float(candle_data.high) if candle_data.high else 0.0 price.high = float(candle_data.high) if candle_data.high else 0.0
price.low = float(candle_data.low) if candle_data.low else 0.0 price.low = float(candle_data.low) if candle_data.low else 0.0
@@ -64,12 +64,12 @@ class CoinBaseWrapper(BaseWrapper):
def get_product(self, asset_id: str) -> ProductInfo: def get_product(self, asset_id: str) -> ProductInfo:
asset_id = self.__format(asset_id) asset_id = self.__format(asset_id)
asset = self.client.get_product(asset_id) asset = self.client.get_product(asset_id)
return get_product(asset) return extract_product(asset)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
all_asset_ids = [self.__format(asset_id) for asset_id in asset_ids] all_asset_ids = [self.__format(asset_id) for asset_id in asset_ids]
assets = self.client.get_products(product_ids=all_asset_ids) assets = self.client.get_products(product_ids=all_asset_ids)
return [get_product(asset) for asset in assets.products] return [extract_product(asset) for asset in assets.products]
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
asset_id = self.__format(asset_id) asset_id = self.__format(asset_id)
@@ -83,4 +83,4 @@ class CoinBaseWrapper(BaseWrapper):
end=str(int(end_time.timestamp())), end=str(int(end_time.timestamp())),
limit=limit limit=limit
) )
return [get_price(candle) for candle in data.candles] return [extract_price(candle) for candle in data.candles]

View File

@@ -1,9 +1,9 @@
import os import os
import requests import requests
from .base import ProductInfo, BaseWrapper, Price from app.markets.base import ProductInfo, BaseWrapper, Price
def get_product(asset_data: dict) -> ProductInfo: def extract_product(asset_data: dict) -> ProductInfo:
product = ProductInfo() product = ProductInfo()
product.id = asset_data.get('FROMSYMBOL', '') + '-' + asset_data.get('TOSYMBOL', '') product.id = asset_data.get('FROMSYMBOL', '') + '-' + asset_data.get('TOSYMBOL', '')
product.symbol = asset_data.get('FROMSYMBOL', '') product.symbol = asset_data.get('FROMSYMBOL', '')
@@ -12,7 +12,7 @@ def get_product(asset_data: dict) -> ProductInfo:
assert product.price > 0, "Invalid price data received from CryptoCompare" assert product.price > 0, "Invalid price data received from CryptoCompare"
return product return product
def get_price(price_data: dict) -> Price: def extract_price(price_data: dict) -> Price:
price = Price() price = Price()
price.high = float(price_data.get('high', 0)) price.high = float(price_data.get('high', 0))
price.low = float(price_data.get('low', 0)) price.low = float(price_data.get('low', 0))
@@ -53,7 +53,7 @@ class CryptoCompareWrapper(BaseWrapper):
"tsyms": self.currency "tsyms": self.currency
}) })
data = response.get('RAW', {}).get(asset_id, {}).get(self.currency, {}) data = response.get('RAW', {}).get(asset_id, {}).get(self.currency, {})
return get_product(data) return extract_product(data)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
response = self.__request("/data/pricemultifull", params = { response = self.__request("/data/pricemultifull", params = {
@@ -64,7 +64,7 @@ class CryptoCompareWrapper(BaseWrapper):
data = response.get('RAW', {}) data = response.get('RAW', {})
for asset_id in asset_ids: for asset_id in asset_ids:
asset_data = data.get(asset_id, {}).get(self.currency, {}) asset_data = data.get(asset_id, {}).get(self.currency, {})
assets.append(get_product(asset_data)) assets.append(extract_product(asset_data))
return assets return assets
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]: def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
@@ -75,5 +75,5 @@ class CryptoCompareWrapper(BaseWrapper):
}) })
data = response.get('Data', {}).get('Data', []) data = response.get('Data', {}).get('Data', [])
prices = [get_price(price_data) for price_data in data] prices = [extract_price(price_data) for price_data in data]
return prices return prices

View File

@@ -1,9 +1,9 @@
import json import json
from agno.tools.yfinance import YFinanceTools from agno.tools.yfinance import YFinanceTools
from .base import BaseWrapper, ProductInfo, Price from app.markets.base import BaseWrapper, ProductInfo, Price
def create_product_info(stock_data: dict[str, str]) -> ProductInfo: def extract_product(stock_data: dict[str, str]) -> ProductInfo:
""" """
Converte i dati di YFinanceTools in ProductInfo. Converte i dati di YFinanceTools in ProductInfo.
""" """
@@ -15,7 +15,7 @@ def create_product_info(stock_data: dict[str, str]) -> ProductInfo:
product.quote_currency = product.id.split('-')[1] # La valuta è la parte dopo il '-' product.quote_currency = product.id.split('-')[1] # La valuta è la parte dopo il '-'
return product return product
def create_price_from_history(hist_data: dict[str, str]) -> Price: def extract_price(hist_data: dict[str, str]) -> Price:
""" """
Converte i dati storici di YFinanceTools in Price. Converte i dati storici di YFinanceTools in Price.
""" """
@@ -52,7 +52,7 @@ class YFinanceWrapper(BaseWrapper):
symbol = self._format_symbol(asset_id) symbol = self._format_symbol(asset_id)
stock_info = self.tool.get_company_info(symbol) stock_info = self.tool.get_company_info(symbol)
stock_info = json.loads(stock_info) stock_info = json.loads(stock_info)
return create_product_info(stock_info) return extract_product(stock_info)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
products = [] products = []
@@ -75,6 +75,6 @@ class YFinanceWrapper(BaseWrapper):
for timestamp in timestamps: for timestamp in timestamps:
temp = hist_data[timestamp] temp = hist_data[timestamp]
temp['Timestamp'] = timestamp temp['Timestamp'] = timestamp
price = create_price_from_history(temp) price = extract_price(temp)
prices.append(price) prices.append(price)
return prices return prices

View File

@@ -5,8 +5,8 @@ from agno.agent import Agent
from agno.models.base import Model from agno.models.base import Model
from agno.models.google import Gemini from agno.models.google import Gemini
from agno.models.ollama import Ollama from agno.models.ollama import Ollama
from agno.utils.log import log_warning
from agno.tools import Toolkit from agno.tools import Toolkit
from agno.utils.log import log_warning
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,10 +1,10 @@
from agno.tools import Toolkit from agno.tools import Toolkit
from app.utils.wrapper_handler import WrapperHandler from app.utils.wrapper_handler import WrapperHandler
from .base import NewsWrapper, Article from app.news.base import NewsWrapper, Article
from .news_api import NewsApiWrapper from app.news.news_api import NewsApiWrapper
from .googlenews import GoogleNewsWrapper from app.news.googlenews import GoogleNewsWrapper
from .cryptopanic_api import CryptoPanicWrapper from app.news.cryptopanic_api import CryptoPanicWrapper
from .duckduckgo import DuckDuckGoWrapper from app.news.duckduckgo import DuckDuckGoWrapper
__all__ = ["NewsAPIsTool", "NEWS_INSTRUCTIONS", "NewsApiWrapper", "GoogleNewsWrapper", "CryptoPanicWrapper", "DuckDuckGoWrapper"] __all__ = ["NewsAPIsTool", "NEWS_INSTRUCTIONS", "NewsApiWrapper", "GoogleNewsWrapper", "CryptoPanicWrapper", "DuckDuckGoWrapper"]
@@ -42,6 +42,8 @@ class NewsAPIsTool(NewsWrapper, Toolkit):
tools=[ tools=[
self.get_top_headlines, self.get_top_headlines,
self.get_latest_news, self.get_latest_news,
self.get_top_headlines_aggregated,
self.get_latest_news_aggregated,
], ],
) )

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel from pydantic import BaseModel
class Article(BaseModel): class Article(BaseModel):
source: str = "" source: str = ""
time: str = "" time: str = ""

View File

@@ -1,7 +1,8 @@
import os import os
import requests import requests
from enum import Enum from enum import Enum
from .base import NewsWrapper, Article from app.news.base import NewsWrapper, Article
class CryptoPanicFilter(Enum): class CryptoPanicFilter(Enum):
RISING = "rising" RISING = "rising"
@@ -18,7 +19,7 @@ class CryptoPanicKind(Enum):
MEDIA = "media" MEDIA = "media"
ALL = "all" ALL = "all"
def get_articles(response: dict) -> list[Article]: def extract_articles(response: dict) -> list[Article]:
articles = [] articles = []
if 'results' in response: if 'results' in response:
for item in response['results']: for item in response['results']:
@@ -73,5 +74,5 @@ class CryptoPanicWrapper(NewsWrapper):
assert response.status_code == 200, f"Error fetching data: {response}" assert response.status_code == 200, f"Error fetching data: {response}"
json_response = response.json() json_response = response.json()
articles = get_articles(json_response) articles = extract_articles(json_response)
return articles[:limit] return articles[:limit]

View File

@@ -1,8 +1,9 @@
import json import json
from .base import Article, NewsWrapper
from agno.tools.duckduckgo import DuckDuckGoTools from agno.tools.duckduckgo import DuckDuckGoTools
from app.news.base import Article, NewsWrapper
def create_article(result: dict) -> Article:
def extract_article(result: dict) -> Article:
article = Article() article = Article()
article.source = result.get("source", "") article.source = result.get("source", "")
article.time = result.get("date", "") article.time = result.get("date", "")
@@ -23,10 +24,10 @@ class DuckDuckGoWrapper(NewsWrapper):
def get_top_headlines(self, limit: int = 100) -> list[Article]: def get_top_headlines(self, limit: int = 100) -> list[Article]:
results = self.tool.duckduckgo_news(self.query, max_results=limit) results = self.tool.duckduckgo_news(self.query, max_results=limit)
json_results = json.loads(results) json_results = json.loads(results)
return [create_article(result) for result in json_results] return [extract_article(result) for result in json_results]
def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: def get_latest_news(self, query: str, limit: int = 100) -> list[Article]:
results = self.tool.duckduckgo_news(query or self.query, max_results=limit) results = self.tool.duckduckgo_news(query or self.query, max_results=limit)
json_results = json.loads(results) json_results = json.loads(results)
return [create_article(result) for result in json_results] return [extract_article(result) for result in json_results]

View File

@@ -1,7 +1,8 @@
from gnews import GNews from gnews import GNews
from .base import Article, NewsWrapper from app.news.base import Article, NewsWrapper
def result_to_article(result: dict) -> Article:
def extract_article(result: dict) -> Article:
article = Article() article = Article()
article.source = result.get("source", "") article.source = result.get("source", "")
article.time = result.get("publishedAt", "") article.time = result.get("publishedAt", "")
@@ -21,7 +22,7 @@ class GoogleNewsWrapper(NewsWrapper):
articles = [] articles = []
for result in results: for result in results:
article = result_to_article(result) article = extract_article(result)
articles.append(article) articles.append(article)
return articles return articles
@@ -31,6 +32,6 @@ class GoogleNewsWrapper(NewsWrapper):
articles = [] articles = []
for result in results: for result in results:
article = result_to_article(result) article = extract_article(result)
articles.append(article) articles.append(article)
return articles return articles

View File

@@ -1,8 +1,9 @@
import os import os
import newsapi import newsapi
from .base import Article, NewsWrapper from app.news.base import Article, NewsWrapper
def result_to_article(result: dict) -> Article:
def extract_article(result: dict) -> Article:
article = Article() article = Article()
article.source = result.get("source", {}).get("name", "") article.source = result.get("source", {}).get("name", "")
article.time = result.get("publishedAt", "") article.time = result.get("publishedAt", "")
@@ -37,7 +38,7 @@ class NewsApiWrapper(NewsWrapper):
for page in range(1, pages + 1): for page in range(1, pages + 1):
headlines = self.client.get_top_headlines(q="", category=self.category, language=self.language, page_size=page_size, page=page) headlines = self.client.get_top_headlines(q="", category=self.category, language=self.language, page_size=page_size, page=page)
results = [result_to_article(article) for article in headlines.get("articles", [])] results = [extract_article(article) for article in headlines.get("articles", [])]
articles.extend(results) articles.extend(results)
return articles return articles
@@ -47,7 +48,7 @@ class NewsApiWrapper(NewsWrapper):
for page in range(1, pages + 1): for page in range(1, pages + 1):
everything = self.client.get_everything(q=query, language=self.language, sort_by="publishedAt", page_size=page_size, page=page) everything = self.client.get_everything(q=query, language=self.language, sort_by="publishedAt", page_size=page_size, page=page)
results = [result_to_article(article) for article in everything.get("articles", [])] results = [extract_article(article) for article in everything.get("articles", [])]
articles.extend(results) articles.extend(results)
return articles return articles

View File

@@ -1,11 +1,10 @@
from agno.run.agent import RunOutput from agno.run.agent import RunOutput
from agno.team import Team from agno.team import Team
from app.news import NewsAPIsTool, NEWS_INSTRUCTIONS
from app.social import SocialAPIsTool, SOCIAL_INSTRUCTIONS
from app.markets import MarketAPIsTool, MARKET_INSTRUCTIONS
from app.models import AppModels from app.models import AppModels
from app.predictor import PredictorStyle, PredictorInput, PredictorOutput, PREDICTOR_INSTRUCTIONS from app.markets import MARKET_INSTRUCTIONS, MarketAPIsTool
from app.news import NEWS_INSTRUCTIONS, NewsAPIsTool
from app.social import SOCIAL_INSTRUCTIONS, SocialAPIsTool
from app.predictor import PREDICTOR_INSTRUCTIONS, PredictorInput, PredictorOutput, PredictorStyle
class Pipeline: class Pipeline:

View File

@@ -1,7 +1,7 @@
from agno.tools import Toolkit from agno.tools import Toolkit
from app.utils.wrapper_handler import WrapperHandler from app.utils.wrapper_handler import WrapperHandler
from .base import SocialPost, SocialWrapper from app.social.base import SocialPost, SocialWrapper
from .reddit import RedditWrapper from app.social.reddit import RedditWrapper
__all__ = ["SocialAPIsTool", "SOCIAL_INSTRUCTIONS", "RedditWrapper"] __all__ = ["SocialAPIsTool", "SOCIAL_INSTRUCTIONS", "RedditWrapper"]

View File

@@ -1,7 +1,8 @@
import os import os
from praw import Reddit from praw import Reddit
from praw.models import Submission, MoreComments from praw.models import Submission, MoreComments
from .base import SocialWrapper, SocialPost, SocialComment from app.social.base import SocialWrapper, SocialPost, SocialComment
MAX_COMMENTS = 5 MAX_COMMENTS = 5
# metterne altri se necessario. # metterne altri se necessario.
@@ -21,7 +22,7 @@ SUBREDDITS = [
] ]
def create_social_post(post: Submission) -> SocialPost: def extract_post(post: Submission) -> SocialPost:
social = SocialPost() social = SocialPost()
social.time = str(post.created) social.time = str(post.created)
social.title = post.title social.title = post.title
@@ -65,4 +66,4 @@ class RedditWrapper(SocialWrapper):
def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]:
top_posts = self.subreddits.top(limit=limit, time_filter="week") top_posts = self.subreddits.top(limit=limit, time_filter="week")
return [create_social_post(post) for post in top_posts] return [extract_post(post) for post in top_posts]

View File

@@ -1,12 +1,13 @@
import inspect import inspect
import time import time
import traceback import traceback
from typing import TypeVar, Callable, Generic, Iterable, Type from typing import Callable, Generic, Iterable, Type, TypeVar
from agno.utils.log import log_warning, log_info from agno.utils.log import log_info, log_warning
W = TypeVar("W") W = TypeVar("W")
T = TypeVar("T") T = TypeVar("T")
class WrapperHandler(Generic[W]): class WrapperHandler(Generic[W]):
""" """
A handler for managing multiple wrappers with retry logic. A handler for managing multiple wrappers with retry logic.