Rinominato BaseWrapper in MarketWrapper e fix type check markets

This commit is contained in:
2025-10-04 19:11:47 +02:00
parent 07ab380669
commit 3a6702642b
7 changed files with 53 additions and 57 deletions

View File

@@ -1,5 +1,5 @@
from agno.tools import Toolkit from agno.tools import Toolkit
from app.markets.base import BaseWrapper, Price, ProductInfo from app.markets.base import MarketWrapper, Price, ProductInfo
from app.markets.binance import BinanceWrapper from app.markets.binance import BinanceWrapper
from app.markets.coinbase import CoinBaseWrapper from app.markets.coinbase import CoinBaseWrapper
from app.markets.cryptocompare import CryptoCompareWrapper from app.markets.cryptocompare import CryptoCompareWrapper
@@ -10,7 +10,7 @@ from app.utils.wrapper_handler import WrapperHandler
__all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "ProductInfo", "Price" ] __all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "ProductInfo", "Price" ]
class MarketAPIsTool(BaseWrapper, Toolkit): class MarketAPIsTool(MarketWrapper, Toolkit):
""" """
Class that aggregates multiple market API wrappers and manages them using WrapperHandler. Class that aggregates multiple market API wrappers and manages them using WrapperHandler.
This class supports retrieving product information and historical prices. This class supports retrieving product information and historical prices.
@@ -34,10 +34,10 @@ class MarketAPIsTool(BaseWrapper, Toolkit):
currency (str): Valuta in cui restituire i prezzi. Default è "USD". currency (str): Valuta in cui restituire i prezzi. Default è "USD".
""" """
kwargs = {"currency": currency or "USD"} kwargs = {"currency": currency or "USD"}
wrappers = [ BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper ] wrappers: list[type[MarketWrapper]] = [BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper]
self.wrappers: WrapperHandler[BaseWrapper] = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs) self.wrappers = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs)
Toolkit.__init__( Toolkit.__init__( # type: ignore
self, self,
name="Market APIs Toolkit", name="Market APIs Toolkit",
tools=[ tools=[
@@ -53,7 +53,7 @@ class MarketAPIsTool(BaseWrapper, Toolkit):
return self.wrappers.try_call(lambda w: w.get_product(asset_id)) return self.wrappers.try_call(lambda w: w.get_product(asset_id))
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
return self.wrappers.try_call(lambda w: w.get_products(asset_ids)) return self.wrappers.try_call(lambda w: w.get_products(asset_ids))
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
return self.wrappers.try_call(lambda w: w.get_historical_prices(asset_id, limit)) return self.wrappers.try_call(lambda w: w.get_historical_prices(asset_id, limit))

View File

@@ -24,7 +24,7 @@ class Price(BaseModel):
volume: float = 0.0 volume: float = 0.0
timestamp_ms: int = 0 # Timestamp in milliseconds timestamp_ms: int = 0 # Timestamp in milliseconds
class BaseWrapper: class MarketWrapper:
""" """
Base class for market API wrappers. Base class for market API wrappers.
All market API wrappers should inherit from this class and implement the methods. All market API wrappers should inherit from this class and implement the methods.
@@ -50,7 +50,7 @@ class BaseWrapper:
""" """
raise NotImplementedError("This method should be overridden by subclasses") raise NotImplementedError("This method should be overridden by subclasses")
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
""" """
Get historical price data for a specific asset ID. Get historical price data for a specific asset ID.
Args: Args:

View File

@@ -1,18 +1,19 @@
import os import os
from binance.client import Client from typing import Any
from app.markets.base import ProductInfo, BaseWrapper, Price from binance.client import Client # type: ignore
from app.markets.base import ProductInfo, MarketWrapper, Price
def extract_product(currency: str, ticker_data: dict[str, str]) -> ProductInfo: def extract_product(currency: str, ticker_data: dict[str, Any]) -> ProductInfo:
product = ProductInfo() product = ProductInfo()
product.id = ticker_data.get('symbol') product.id = ticker_data.get('symbol', '')
product.symbol = ticker_data.get('symbol', '').replace(currency, '') product.symbol = ticker_data.get('symbol', '').replace(currency, '')
product.price = float(ticker_data.get('price', 0)) product.price = float(ticker_data.get('price', 0))
product.volume_24h = float(ticker_data.get('volume', 0)) product.volume_24h = float(ticker_data.get('volume', 0))
product.quote_currency = currency product.quote_currency = currency
return product return product
def extract_price(kline_data: list) -> Price: def extract_price(kline_data: list[Any]) -> Price:
price = Price() price = Price()
price.open = float(kline_data[1]) price.open = float(kline_data[1])
price.high = float(kline_data[2]) price.high = float(kline_data[2])
@@ -22,7 +23,7 @@ def extract_price(kline_data: list) -> Price:
price.timestamp_ms = kline_data[0] price.timestamp_ms = kline_data[0]
return price return price
class BinanceWrapper(BaseWrapper): class BinanceWrapper(MarketWrapper):
""" """
Wrapper per le API autenticate di Binance.\n Wrapper per le API autenticate di Binance.\n
Implementa l'interfaccia BaseWrapper per fornire accesso unificato Implementa l'interfaccia BaseWrapper per fornire accesso unificato
@@ -46,31 +47,22 @@ class BinanceWrapper(BaseWrapper):
def get_product(self, asset_id: str) -> ProductInfo: def get_product(self, asset_id: str) -> ProductInfo:
symbol = self.__format_symbol(asset_id) symbol = self.__format_symbol(asset_id)
ticker = self.client.get_symbol_ticker(symbol=symbol) ticker: dict[str, Any] = self.client.get_symbol_ticker(symbol=symbol) # type: ignore
ticker_24h = self.client.get_ticker(symbol=symbol) ticker_24h: dict[str, Any] = self.client.get_ticker(symbol=symbol) # type: ignore
ticker['volume'] = ticker_24h.get('volume', 0) # Aggiunge volume 24h ai dati del ticker ticker['volume'] = ticker_24h.get('volume', 0)
return extract_product(self.currency, ticker) return extract_product(self.currency, ticker)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
symbols = [self.__format_symbol(asset_id) for asset_id in asset_ids] return [ self.get_product(asset_id) for asset_id in asset_ids ]
symbols_str = f"[\"{'","'.join(symbols)}\"]"
tickers = self.client.get_symbol_ticker(symbols=symbols_str) def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
tickers_24h = self.client.get_ticker(symbols=symbols_str) # un po brutale, ma va bene così
for t, t24 in zip(tickers, tickers_24h):
t['volume'] = t24.get('volume', 0)
return [extract_product(self.currency, ticker) for ticker in tickers]
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
symbol = self.__format_symbol(asset_id) symbol = self.__format_symbol(asset_id)
# Ottiene candele orarie degli ultimi 30 giorni # Ottiene candele orarie degli ultimi 30 giorni
klines = self.client.get_historical_klines( klines: list[list[Any]] = self.client.get_historical_klines( # type: ignore
symbol=symbol, symbol=symbol,
interval=Client.KLINE_INTERVAL_1HOUR, interval=Client.KLINE_INTERVAL_1HOUR,
limit=limit, limit=limit,
) )
return [extract_price(kline) for kline in klines] return [extract_price(kline) for kline in klines]

View File

@@ -1,9 +1,9 @@
import os import os
from enum import Enum from enum import Enum
from datetime import datetime, timedelta from datetime import datetime, timedelta
from coinbase.rest import RESTClient from coinbase.rest import RESTClient # type: ignore
from coinbase.rest.types.product_types import Candle, GetProductResponse, Product from coinbase.rest.types.product_types import Candle, GetProductResponse, Product # type: ignore
from app.markets.base import ProductInfo, BaseWrapper, Price from app.markets.base import ProductInfo, MarketWrapper, Price
def extract_product(product_data: GetProductResponse | Product) -> ProductInfo: def extract_product(product_data: GetProductResponse | Product) -> ProductInfo:
@@ -37,7 +37,7 @@ class Granularity(Enum):
SIX_HOUR = 21600 SIX_HOUR = 21600
ONE_DAY = 86400 ONE_DAY = 86400
class CoinBaseWrapper(BaseWrapper): class CoinBaseWrapper(MarketWrapper):
""" """
Wrapper per le API di Coinbase Advanced Trade.\n Wrapper per le API di Coinbase Advanced Trade.\n
Implementa l'interfaccia BaseWrapper per fornire accesso unificato Implementa l'interfaccia BaseWrapper per fornire accesso unificato
@@ -63,24 +63,26 @@ class CoinBaseWrapper(BaseWrapper):
def get_product(self, asset_id: str) -> ProductInfo: def get_product(self, asset_id: str) -> ProductInfo:
asset_id = self.__format(asset_id) asset_id = self.__format(asset_id)
asset = self.client.get_product(asset_id) asset = self.client.get_product(asset_id) # type: ignore
return extract_product(asset) return extract_product(asset)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
all_asset_ids = [self.__format(asset_id) for asset_id in asset_ids] all_asset_ids = [self.__format(asset_id) for asset_id in asset_ids]
assets = self.client.get_products(product_ids=all_asset_ids) assets = self.client.get_products(product_ids=all_asset_ids) # type: ignore
assert assets.products is not None, "No products data received from Coinbase"
return [extract_product(asset) for asset in assets.products] return [extract_product(asset) for asset in assets.products]
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
asset_id = self.__format(asset_id) asset_id = self.__format(asset_id)
end_time = datetime.now() end_time = datetime.now()
start_time = end_time - timedelta(days=14) start_time = end_time - timedelta(days=14)
data = self.client.get_candles( data = self.client.get_candles( # type: ignore
product_id=asset_id, product_id=asset_id,
granularity=Granularity.ONE_HOUR.name, granularity=Granularity.ONE_HOUR.name,
start=str(int(start_time.timestamp())), start=str(int(start_time.timestamp())),
end=str(int(end_time.timestamp())), end=str(int(end_time.timestamp())),
limit=limit limit=limit
) )
assert data.candles is not None, "No candles data received from Coinbase"
return [extract_price(candle) for candle in data.candles] return [extract_price(candle) for candle in data.candles]

View File

@@ -1,9 +1,10 @@
import os import os
from typing import Any
import requests import requests
from app.markets.base import ProductInfo, BaseWrapper, Price from app.markets.base import ProductInfo, MarketWrapper, Price
def extract_product(asset_data: dict) -> ProductInfo: def extract_product(asset_data: dict[str, Any]) -> ProductInfo:
product = ProductInfo() product = ProductInfo()
product.id = asset_data.get('FROMSYMBOL', '') + '-' + asset_data.get('TOSYMBOL', '') product.id = asset_data.get('FROMSYMBOL', '') + '-' + asset_data.get('TOSYMBOL', '')
product.symbol = asset_data.get('FROMSYMBOL', '') product.symbol = asset_data.get('FROMSYMBOL', '')
@@ -12,7 +13,7 @@ def extract_product(asset_data: dict) -> ProductInfo:
assert product.price > 0, "Invalid price data received from CryptoCompare" assert product.price > 0, "Invalid price data received from CryptoCompare"
return product return product
def extract_price(price_data: dict) -> Price: def extract_price(price_data: dict[str, Any]) -> Price:
price = Price() price = Price()
price.high = float(price_data.get('high', 0)) price.high = float(price_data.get('high', 0))
price.low = float(price_data.get('low', 0)) price.low = float(price_data.get('low', 0))
@@ -26,7 +27,7 @@ def extract_price(price_data: dict) -> Price:
BASE_URL = "https://min-api.cryptocompare.com" BASE_URL = "https://min-api.cryptocompare.com"
class CryptoCompareWrapper(BaseWrapper): class CryptoCompareWrapper(MarketWrapper):
""" """
Wrapper per le API pubbliche di CryptoCompare. Wrapper per le API pubbliche di CryptoCompare.
La documentazione delle API è disponibile qui: https://developers.coindesk.com/documentation/legacy/Price/SingleSymbolPriceEndpoint La documentazione delle API è disponibile qui: https://developers.coindesk.com/documentation/legacy/Price/SingleSymbolPriceEndpoint
@@ -39,7 +40,7 @@ class CryptoCompareWrapper(BaseWrapper):
self.api_key = api_key self.api_key = api_key
self.currency = currency self.currency = currency
def __request(self, endpoint: str, params: dict[str, str] | None = None) -> dict[str, str]: def __request(self, endpoint: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
if params is None: if params is None:
params = {} params = {}
params['api_key'] = self.api_key params['api_key'] = self.api_key
@@ -60,7 +61,7 @@ class CryptoCompareWrapper(BaseWrapper):
"fsyms": ",".join(asset_ids), "fsyms": ",".join(asset_ids),
"tsyms": self.currency "tsyms": self.currency
}) })
assets = [] assets: list[ProductInfo] = []
data = response.get('RAW', {}) data = response.get('RAW', {})
for asset_id in asset_ids: for asset_id in asset_ids:
asset_data = data.get(asset_id, {}).get(self.currency, {}) asset_data = data.get(asset_id, {}).get(self.currency, {})

View File

@@ -1,6 +1,6 @@
import json import json
from agno.tools.yfinance import YFinanceTools from agno.tools.yfinance import YFinanceTools
from app.markets.base import BaseWrapper, ProductInfo, Price from app.markets.base import MarketWrapper, ProductInfo, Price
def extract_product(stock_data: dict[str, str]) -> ProductInfo: def extract_product(stock_data: dict[str, str]) -> ProductInfo:
@@ -29,7 +29,7 @@ def extract_price(hist_data: dict[str, str]) -> Price:
return price return price
class YFinanceWrapper(BaseWrapper): class YFinanceWrapper(MarketWrapper):
""" """
Wrapper per YFinanceTools che fornisce dati di mercato per azioni, ETF e criptovalute. Wrapper per YFinanceTools che fornisce dati di mercato per azioni, ETF e criptovalute.
Implementa l'interfaccia BaseWrapper per compatibilità con il sistema esistente. Implementa l'interfaccia BaseWrapper per compatibilità con il sistema esistente.
@@ -55,13 +55,13 @@ class YFinanceWrapper(BaseWrapper):
return extract_product(stock_info) return extract_product(stock_info)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]: def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
products = [] products: list[ProductInfo] = []
for asset_id in asset_ids: for asset_id in asset_ids:
product = self.get_product(asset_id) product = self.get_product(asset_id)
products.append(product) products.append(product)
return products return products
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]: def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
symbol = self._format_symbol(asset_id) symbol = self._format_symbol(asset_id)
days = limit // 24 + 1 # Arrotonda per eccesso days = limit // 24 + 1 # Arrotonda per eccesso
@@ -71,7 +71,7 @@ class YFinanceWrapper(BaseWrapper):
# Il formato dei dati è {timestamp: {Open: x, High: y, Low: z, Close: w, Volume: v}} # Il formato dei dati è {timestamp: {Open: x, High: y, Low: z, Close: w, Volume: v}}
timestamps = sorted(hist_data.keys())[-limit:] timestamps = sorted(hist_data.keys())[-limit:]
prices = [] prices: list[Price] = []
for timestamp in timestamps: for timestamp in timestamps:
temp = hist_data[timestamp] temp = hist_data[timestamp]
temp['Timestamp'] = timestamp temp['Timestamp'] = timestamp

View File

@@ -4,11 +4,12 @@ import traceback
from typing import Any, Callable, Generic, TypeVar from typing import Any, Callable, Generic, TypeVar
from agno.utils.log import log_info, log_warning #type: ignore from agno.utils.log import log_info, log_warning #type: ignore
W = TypeVar("W") WrapperType = TypeVar("WrapperType")
T = TypeVar("T") WrapperClassType = TypeVar("WrapperClassType")
OutputType = TypeVar("OutputType")
class WrapperHandler(Generic[W]): class WrapperHandler(Generic[WrapperType]):
""" """
A handler for managing multiple wrappers with retry logic. A handler for managing multiple wrappers with retry logic.
It attempts to call a function on the current wrapper, and if it fails, It attempts to call a function on the current wrapper, and if it fails,
@@ -18,7 +19,7 @@ class WrapperHandler(Generic[W]):
Note: use `build_wrappers` to create an instance of this class for better error handling. Note: use `build_wrappers` to create an instance of this class for better error handling.
""" """
def __init__(self, wrappers: list[W], try_per_wrapper: int = 3, retry_delay: int = 2): def __init__(self, wrappers: list[WrapperType], try_per_wrapper: int = 3, retry_delay: int = 2):
""" """
Initializes the WrapperHandler with a list of wrappers and retry settings.\n Initializes the WrapperHandler with a list of wrappers and retry settings.\n
Use `build_wrappers` to create an instance of this class for better error handling. Use `build_wrappers` to create an instance of this class for better error handling.
@@ -35,7 +36,7 @@ class WrapperHandler(Generic[W]):
self.index = 0 self.index = 0
self.retry_count = 0 self.retry_count = 0
def try_call(self, func: Callable[[W], T]) -> T: def try_call(self, func: Callable[[WrapperType], OutputType]) -> OutputType:
""" """
Attempts to call the provided function on the current wrapper. Attempts to call the provided function on the current wrapper.
If it fails, it retries a specified number of times before switching to the next wrapper. If it fails, it retries a specified number of times before switching to the next wrapper.
@@ -76,7 +77,7 @@ class WrapperHandler(Generic[W]):
raise Exception(f"All wrappers failed, latest error: {error}") raise Exception(f"All wrappers failed, latest error: {error}")
def try_call_all(self, func: Callable[[W], T]) -> dict[str, T]: def try_call_all(self, func: Callable[[WrapperType], OutputType]) -> dict[str, OutputType]:
""" """
Calls the provided function on all wrappers, collecting results. Calls the provided function on all wrappers, collecting results.
If a wrapper fails, it logs a warning and continues with the next. If a wrapper fails, it logs a warning and continues with the next.
@@ -90,7 +91,7 @@ class WrapperHandler(Generic[W]):
""" """
log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}") log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}")
results: dict[str, T] = {} results: dict[str, OutputType] = {}
error = "" error = ""
for wrapper in self.wrappers: for wrapper in self.wrappers:
wrapper_name = wrapper.__class__.__name__ wrapper_name = wrapper.__class__.__name__
@@ -115,7 +116,7 @@ class WrapperHandler(Generic[W]):
return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]" return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]"
@staticmethod @staticmethod
def build_wrappers(constructors: list[type[W]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[W]': def build_wrappers(constructors: list[type[WrapperClassType]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[WrapperClassType]':
""" """
Builds a WrapperHandler instance with the given wrapper constructors. Builds a WrapperHandler instance with the given wrapper constructors.
It attempts to initialize each wrapper and logs a warning if any cannot be initialized. It attempts to initialize each wrapper and logs a warning if any cannot be initialized.
@@ -132,7 +133,7 @@ class WrapperHandler(Generic[W]):
""" """
assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}" assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}"
result: list[W] = [] result: list[WrapperClassType] = []
for wrapper_class in constructors: for wrapper_class in constructors:
try: try:
wrapper = wrapper_class(**(kwargs or {})) wrapper = wrapper_class(**(kwargs or {}))