From f1bf00c75931a50c88f13974a7cda291df55cff0 Mon Sep 17 00:00:00 2001 From: Berack96 Date: Thu, 9 Oct 2025 23:11:57 +0200 Subject: [PATCH] Refactor: Introduce WrapperHandler for managing API wrappers with retry logic; update related modules and tests --- src/app/api/__init__.py | 0 src/app/api/base/markets.py | 64 ++++++++++++++++++++++ src/app/api/markets/__init__.py | 16 +++--- src/app/api/news/__init__.py | 12 ++--- src/app/api/social/__init__.py | 8 +-- src/app/{utils => api}/wrapper_handler.py | 0 src/app/utils/__init__.py | 4 +- src/app/utils/market_aggregation.py | 65 ----------------------- tests/tools/test_news_tool.py | 8 +-- tests/tools/test_socials_tool.py | 4 +- tests/utils/test_market_aggregator.py | 3 +- tests/utils/test_wrapper_handler.py | 2 +- 12 files changed, 91 insertions(+), 95 deletions(-) create mode 100644 src/app/api/__init__.py rename src/app/{utils => api}/wrapper_handler.py (100%) delete mode 100644 src/app/utils/market_aggregation.py diff --git a/src/app/api/__init__.py b/src/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/api/base/markets.py b/src/app/api/base/markets.py index cd00879..e6a0657 100644 --- a/src/app/api/base/markets.py +++ b/src/app/api/base/markets.py @@ -1,3 +1,4 @@ +import statistics from datetime import datetime from pydantic import BaseModel @@ -81,3 +82,66 @@ class MarketWrapper: list[Price]: A list of Price objects. """ raise NotImplementedError("This method should be overridden by subclasses") + + +def aggregate_history_prices(prices: dict[str, list[Price]]) -> list[Price]: + """ + Aggrega i prezzi storici per symbol calcolando la media. + Args: + prices (dict[str, list[Price]]): Mappa provider -> lista di Price + Returns: + list[Price]: Lista di Price aggregati per timestamp + """ + + # Costruiamo una mappa timestamp -> lista di Price + timestamped_prices: dict[str, list[Price]] = {} + for _, price_list in prices.items(): + for price in price_list: + timestamped_prices.setdefault(price.timestamp, []).append(price) + + # Ora aggregiamo i prezzi per ogni timestamp + aggregated_prices: list[Price] = [] + for time, price_list in timestamped_prices.items(): + price = Price() + price.timestamp = time + price.high = statistics.mean([p.high for p in price_list]) + price.low = statistics.mean([p.low for p in price_list]) + price.open = statistics.mean([p.open for p in price_list]) + price.close = statistics.mean([p.close for p in price_list]) + price.volume = statistics.mean([p.volume for p in price_list]) + aggregated_prices.append(price) + return aggregated_prices + +def aggregate_product_info(products: dict[str, list[ProductInfo]]) -> list[ProductInfo]: + """ + Aggrega una lista di ProductInfo per symbol. + Args: + products (dict[str, list[ProductInfo]]): Mappa provider -> lista di ProductInfo + Returns: + list[ProductInfo]: Lista di ProductInfo aggregati per symbol + """ + + # Costruzione mappa symbol -> lista di ProductInfo + symbols_infos: dict[str, list[ProductInfo]] = {} + for _, product_list in products.items(): + for product in product_list: + symbols_infos.setdefault(product.symbol, []).append(product) + + # Aggregazione per ogni symbol + aggregated_products: list[ProductInfo] = [] + for symbol, product_list in symbols_infos.items(): + product = ProductInfo() + + product.id = f"{symbol}_AGGREGATED" + product.symbol = symbol + product.currency = next(p.currency for p in product_list if p.currency) + + volume_sum = sum(p.volume_24h for p in product_list) + product.volume_24h = volume_sum / len(product_list) if product_list else 0.0 + + prices = sum(p.price * p.volume_24h for p in product_list) + product.price = (prices / volume_sum) if volume_sum > 0 else 0.0 + + aggregated_products.append(product) + return aggregated_products + diff --git a/src/app/api/markets/__init__.py b/src/app/api/markets/__init__.py index 2b0de19..784913a 100644 --- a/src/app/api/markets/__init__.py +++ b/src/app/api/markets/__init__.py @@ -1,10 +1,10 @@ from agno.tools import Toolkit -from app.api.base.markets import MarketWrapper, Price, ProductInfo +from app.api.wrapper_handler import WrapperHandler +from app.api.base.markets import MarketWrapper, Price, ProductInfo, aggregate_history_prices, aggregate_product_info from app.api.markets.binance import BinanceWrapper from app.api.markets.coinbase import CoinBaseWrapper from app.api.markets.cryptocompare import CryptoCompareWrapper from app.api.markets.yfinance import YFinanceWrapper -from app.utils import aggregate_history_prices, aggregate_product_info, WrapperHandler __all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "ProductInfo", "Price" ] @@ -34,7 +34,7 @@ class MarketAPIsTool(MarketWrapper, Toolkit): """ kwargs = {"currency": currency or "USD"} wrappers: list[type[MarketWrapper]] = [BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper] - self.wrappers = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs) + self.handler = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs) Toolkit.__init__( # type: ignore self, @@ -49,11 +49,11 @@ class MarketAPIsTool(MarketWrapper, Toolkit): ) def get_product(self, asset_id: str) -> ProductInfo: - return self.wrappers.try_call(lambda w: w.get_product(asset_id)) + return self.handler.try_call(lambda w: w.get_product(asset_id)) def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: - return self.wrappers.try_call(lambda w: w.get_products(asset_ids)) + return self.handler.try_call(lambda w: w.get_products(asset_ids)) def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]: - return self.wrappers.try_call(lambda w: w.get_historical_prices(asset_id, limit)) + return self.handler.try_call(lambda w: w.get_historical_prices(asset_id, limit)) def get_products_aggregated(self, asset_ids: list[str]) -> list[ProductInfo]: @@ -67,7 +67,7 @@ class MarketAPIsTool(MarketWrapper, Toolkit): Raises: Exception: If all wrappers fail to provide results. """ - all_products = self.wrappers.try_call_all(lambda w: w.get_products(asset_ids)) + all_products = self.handler.try_call_all(lambda w: w.get_products(asset_ids)) return aggregate_product_info(all_products) def get_historical_prices_aggregated(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: @@ -82,5 +82,5 @@ class MarketAPIsTool(MarketWrapper, Toolkit): Raises: Exception: If all wrappers fail to provide results. """ - all_prices = self.wrappers.try_call_all(lambda w: w.get_historical_prices(asset_id, limit)) + all_prices = self.handler.try_call_all(lambda w: w.get_historical_prices(asset_id, limit)) return aggregate_history_prices(all_prices) diff --git a/src/app/api/news/__init__.py b/src/app/api/news/__init__.py index e9701b8..a66cf05 100644 --- a/src/app/api/news/__init__.py +++ b/src/app/api/news/__init__.py @@ -1,5 +1,5 @@ from agno.tools import Toolkit -from app.utils import WrapperHandler +from app.api.wrapper_handler import WrapperHandler from app.api.base.news import NewsWrapper, Article from app.api.news.news_api import NewsApiWrapper from app.api.news.googlenews import GoogleNewsWrapper @@ -34,7 +34,7 @@ class NewsAPIsTool(NewsWrapper, Toolkit): - CryptoPanicWrapper. """ wrappers: list[type[NewsWrapper]] = [GoogleNewsWrapper, DuckDuckGoWrapper, NewsApiWrapper, CryptoPanicWrapper] - self.wrapper_handler = WrapperHandler.build_wrappers(wrappers) + self.handler = WrapperHandler.build_wrappers(wrappers) Toolkit.__init__( # type: ignore self, @@ -48,9 +48,9 @@ class NewsAPIsTool(NewsWrapper, Toolkit): ) def get_top_headlines(self, limit: int = 100) -> list[Article]: - return self.wrapper_handler.try_call(lambda w: w.get_top_headlines(limit)) + return self.handler.try_call(lambda w: w.get_top_headlines(limit)) def get_latest_news(self, query: str, limit: int = 100) -> list[Article]: - return self.wrapper_handler.try_call(lambda w: w.get_latest_news(query, limit)) + return self.handler.try_call(lambda w: w.get_latest_news(query, limit)) def get_top_headlines_aggregated(self, limit: int = 100) -> dict[str, list[Article]]: """ @@ -62,7 +62,7 @@ class NewsAPIsTool(NewsWrapper, Toolkit): Raises: Exception: If all wrappers fail to provide results. """ - return self.wrapper_handler.try_call_all(lambda w: w.get_top_headlines(limit)) + return self.handler.try_call_all(lambda w: w.get_top_headlines(limit)) def get_latest_news_aggregated(self, query: str, limit: int = 100) -> dict[str, list[Article]]: """ @@ -75,4 +75,4 @@ class NewsAPIsTool(NewsWrapper, Toolkit): Raises: Exception: If all wrappers fail to provide results. """ - return self.wrapper_handler.try_call_all(lambda w: w.get_latest_news(query, limit)) + return self.handler.try_call_all(lambda w: w.get_latest_news(query, limit)) diff --git a/src/app/api/social/__init__.py b/src/app/api/social/__init__.py index 0e84809..69d4331 100644 --- a/src/app/api/social/__init__.py +++ b/src/app/api/social/__init__.py @@ -1,5 +1,5 @@ from agno.tools import Toolkit -from app.utils import WrapperHandler +from app.api.wrapper_handler import WrapperHandler from app.api.base.social import SocialPost, SocialWrapper from app.api.social.reddit import RedditWrapper @@ -26,7 +26,7 @@ class SocialAPIsTool(SocialWrapper, Toolkit): """ wrappers: list[type[SocialWrapper]] = [RedditWrapper] - self.wrapper_handler = WrapperHandler.build_wrappers(wrappers) + self.handler = WrapperHandler.build_wrappers(wrappers) Toolkit.__init__( # type: ignore self, @@ -38,7 +38,7 @@ class SocialAPIsTool(SocialWrapper, Toolkit): ) def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: - return self.wrapper_handler.try_call(lambda w: w.get_top_crypto_posts(limit)) + return self.handler.try_call(lambda w: w.get_top_crypto_posts(limit)) def get_top_crypto_posts_aggregated(self, limit_per_wrapper: int = 5) -> dict[str, list[SocialPost]]: """ @@ -50,4 +50,4 @@ class SocialAPIsTool(SocialWrapper, Toolkit): Raises: Exception: If all wrappers fail to provide results. """ - return self.wrapper_handler.try_call_all(lambda w: w.get_top_crypto_posts(limit_per_wrapper)) + return self.handler.try_call_all(lambda w: w.get_top_crypto_posts(limit_per_wrapper)) diff --git a/src/app/utils/wrapper_handler.py b/src/app/api/wrapper_handler.py similarity index 100% rename from src/app/utils/wrapper_handler.py rename to src/app/api/wrapper_handler.py diff --git a/src/app/utils/__init__.py b/src/app/utils/__init__.py index 1a511c1..579b141 100644 --- a/src/app/utils/__init__.py +++ b/src/app/utils/__init__.py @@ -1,5 +1,3 @@ -from app.utils.market_aggregation import aggregate_history_prices, aggregate_product_info -from app.utils.wrapper_handler import WrapperHandler from app.utils.chat_manager import ChatManager -__all__ = ["aggregate_history_prices", "aggregate_product_info", "WrapperHandler", "ChatManager"] +__all__ = ["ChatManager"] diff --git a/src/app/utils/market_aggregation.py b/src/app/utils/market_aggregation.py deleted file mode 100644 index 2915c30..0000000 --- a/src/app/utils/market_aggregation.py +++ /dev/null @@ -1,65 +0,0 @@ -import statistics -from app.api.base.markets import ProductInfo, Price - - -def aggregate_history_prices(prices: dict[str, list[Price]]) -> list[Price]: - """ - Aggrega i prezzi storici per symbol calcolando la media. - Args: - prices (dict[str, list[Price]]): Mappa provider -> lista di Price - Returns: - list[Price]: Lista di Price aggregati per timestamp - """ - - # Costruiamo una mappa timestamp -> lista di Price - timestamped_prices: dict[str, list[Price]] = {} - for _, price_list in prices.items(): - for price in price_list: - timestamped_prices.setdefault(price.timestamp, []).append(price) - - # Ora aggregiamo i prezzi per ogni timestamp - aggregated_prices: list[Price] = [] - for time, price_list in timestamped_prices.items(): - price = Price() - price.timestamp = time - price.high = statistics.mean([p.high for p in price_list]) - price.low = statistics.mean([p.low for p in price_list]) - price.open = statistics.mean([p.open for p in price_list]) - price.close = statistics.mean([p.close for p in price_list]) - price.volume = statistics.mean([p.volume for p in price_list]) - aggregated_prices.append(price) - return aggregated_prices - -def aggregate_product_info(products: dict[str, list[ProductInfo]]) -> list[ProductInfo]: - """ - Aggrega una lista di ProductInfo per symbol. - Args: - products (dict[str, list[ProductInfo]]): Mappa provider -> lista di ProductInfo - Returns: - list[ProductInfo]: Lista di ProductInfo aggregati per symbol - """ - - # Costruzione mappa symbol -> lista di ProductInfo - symbols_infos: dict[str, list[ProductInfo]] = {} - for _, product_list in products.items(): - for product in product_list: - symbols_infos.setdefault(product.symbol, []).append(product) - - # Aggregazione per ogni symbol - aggregated_products: list[ProductInfo] = [] - for symbol, product_list in symbols_infos.items(): - product = ProductInfo() - - product.id = f"{symbol}_AGGREGATED" - product.symbol = symbol - product.currency = next(p.currency for p in product_list if p.currency) - - volume_sum = sum(p.volume_24h for p in product_list) - product.volume_24h = volume_sum / len(product_list) if product_list else 0.0 - - prices = sum(p.price * p.volume_24h for p in product_list) - product.price = (prices / volume_sum) if volume_sum > 0 else 0.0 - - aggregated_products.append(product) - return aggregated_products - diff --git a/tests/tools/test_news_tool.py b/tests/tools/test_news_tool.py index 8193429..5f685a8 100644 --- a/tests/tools/test_news_tool.py +++ b/tests/tools/test_news_tool.py @@ -12,7 +12,7 @@ class TestNewsAPITool: def test_news_api_tool_get_top(self): tool = NewsAPIsTool() - result = tool.wrapper_handler.try_call(lambda w: w.get_top_headlines(limit=2)) + result = tool.handler.try_call(lambda w: w.get_top_headlines(limit=2)) assert isinstance(result, list) assert len(result) > 0 for article in result: @@ -21,7 +21,7 @@ class TestNewsAPITool: def test_news_api_tool_get_latest(self): tool = NewsAPIsTool() - result = tool.wrapper_handler.try_call(lambda w: w.get_latest_news(query="crypto", limit=2)) + result = tool.handler.try_call(lambda w: w.get_latest_news(query="crypto", limit=2)) assert isinstance(result, list) assert len(result) > 0 for article in result: @@ -30,7 +30,7 @@ class TestNewsAPITool: def test_news_api_tool_get_top__all_results(self): tool = NewsAPIsTool() - result = tool.wrapper_handler.try_call_all(lambda w: w.get_top_headlines(limit=2)) + result = tool.handler.try_call_all(lambda w: w.get_top_headlines(limit=2)) assert isinstance(result, dict) assert len(result.keys()) > 0 for _provider, articles in result.items(): @@ -40,7 +40,7 @@ class TestNewsAPITool: def test_news_api_tool_get_latest__all_results(self): tool = NewsAPIsTool() - result = tool.wrapper_handler.try_call_all(lambda w: w.get_latest_news(query="crypto", limit=2)) + result = tool.handler.try_call_all(lambda w: w.get_latest_news(query="crypto", limit=2)) assert isinstance(result, dict) assert len(result.keys()) > 0 for _provider, articles in result.items(): diff --git a/tests/tools/test_socials_tool.py b/tests/tools/test_socials_tool.py index b1b8fd8..29a81ae 100644 --- a/tests/tools/test_socials_tool.py +++ b/tests/tools/test_socials_tool.py @@ -12,7 +12,7 @@ class TestSocialAPIsTool: def test_social_api_tool_get_top(self): tool = SocialAPIsTool() - result = tool.wrapper_handler.try_call(lambda w: w.get_top_crypto_posts(limit=2)) + result = tool.handler.try_call(lambda w: w.get_top_crypto_posts(limit=2)) assert isinstance(result, list) assert len(result) > 0 for post in result: @@ -21,7 +21,7 @@ class TestSocialAPIsTool: def test_social_api_tool_get_top__all_results(self): tool = SocialAPIsTool() - result = tool.wrapper_handler.try_call_all(lambda w: w.get_top_crypto_posts(limit=2)) + result = tool.handler.try_call_all(lambda w: w.get_top_crypto_posts(limit=2)) assert isinstance(result, dict) assert len(result.keys()) > 0 for _provider, posts in result.items(): diff --git a/tests/utils/test_market_aggregator.py b/tests/utils/test_market_aggregator.py index 02e287f..fb789b3 100644 --- a/tests/utils/test_market_aggregator.py +++ b/tests/utils/test_market_aggregator.py @@ -1,7 +1,6 @@ import pytest from datetime import datetime -from app.api.base.markets import ProductInfo, Price -from app.utils.market_aggregation import aggregate_history_prices, aggregate_product_info +from app.api.base.markets import ProductInfo, Price, aggregate_history_prices, aggregate_product_info @pytest.mark.aggregator diff --git a/tests/utils/test_wrapper_handler.py b/tests/utils/test_wrapper_handler.py index c6094a1..86922ab 100644 --- a/tests/utils/test_wrapper_handler.py +++ b/tests/utils/test_wrapper_handler.py @@ -1,5 +1,5 @@ import pytest -from app.utils.wrapper_handler import WrapperHandler +from app.api.wrapper_handler import WrapperHandler class MockWrapper: def do_something(self) -> str: