Rinominato BaseWrapper in MarketWrapper e fix type check markets

This commit is contained in:
2025-10-04 19:11:47 +02:00
parent 07ab380669
commit 3a6702642b
7 changed files with 53 additions and 57 deletions

View File

@@ -1,5 +1,5 @@
from agno.tools import Toolkit
from app.markets.base import BaseWrapper, Price, ProductInfo
from app.markets.base import MarketWrapper, Price, ProductInfo
from app.markets.binance import BinanceWrapper
from app.markets.coinbase import CoinBaseWrapper
from app.markets.cryptocompare import CryptoCompareWrapper
@@ -10,7 +10,7 @@ from app.utils.wrapper_handler import WrapperHandler
__all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "ProductInfo", "Price" ]
class MarketAPIsTool(BaseWrapper, Toolkit):
class MarketAPIsTool(MarketWrapper, Toolkit):
"""
Class that aggregates multiple market API wrappers and manages them using WrapperHandler.
This class supports retrieving product information and historical prices.
@@ -34,10 +34,10 @@ class MarketAPIsTool(BaseWrapper, Toolkit):
currency (str): Valuta in cui restituire i prezzi. Default è "USD".
"""
kwargs = {"currency": currency or "USD"}
wrappers = [ BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper ]
self.wrappers: WrapperHandler[BaseWrapper] = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs)
wrappers: list[type[MarketWrapper]] = [BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper]
self.wrappers = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs)
Toolkit.__init__(
Toolkit.__init__( # type: ignore
self,
name="Market APIs Toolkit",
tools=[
@@ -53,7 +53,7 @@ class MarketAPIsTool(BaseWrapper, Toolkit):
return self.wrappers.try_call(lambda w: w.get_product(asset_id))
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
return self.wrappers.try_call(lambda w: w.get_products(asset_ids))
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
return self.wrappers.try_call(lambda w: w.get_historical_prices(asset_id, limit))

View File

@@ -24,7 +24,7 @@ class Price(BaseModel):
volume: float = 0.0
timestamp_ms: int = 0 # Timestamp in milliseconds
class BaseWrapper:
class MarketWrapper:
"""
Base class for market API wrappers.
All market API wrappers should inherit from this class and implement the methods.
@@ -50,7 +50,7 @@ class BaseWrapper:
"""
raise NotImplementedError("This method should be overridden by subclasses")
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
"""
Get historical price data for a specific asset ID.
Args:

View File

@@ -1,18 +1,19 @@
import os
from binance.client import Client
from app.markets.base import ProductInfo, BaseWrapper, Price
from typing import Any
from binance.client import Client # type: ignore
from app.markets.base import ProductInfo, MarketWrapper, Price
def extract_product(currency: str, ticker_data: dict[str, str]) -> ProductInfo:
def extract_product(currency: str, ticker_data: dict[str, Any]) -> ProductInfo:
product = ProductInfo()
product.id = ticker_data.get('symbol')
product.id = ticker_data.get('symbol', '')
product.symbol = ticker_data.get('symbol', '').replace(currency, '')
product.price = float(ticker_data.get('price', 0))
product.volume_24h = float(ticker_data.get('volume', 0))
product.quote_currency = currency
return product
def extract_price(kline_data: list) -> Price:
def extract_price(kline_data: list[Any]) -> Price:
price = Price()
price.open = float(kline_data[1])
price.high = float(kline_data[2])
@@ -22,7 +23,7 @@ def extract_price(kline_data: list) -> Price:
price.timestamp_ms = kline_data[0]
return price
class BinanceWrapper(BaseWrapper):
class BinanceWrapper(MarketWrapper):
"""
Wrapper per le API autenticate di Binance.\n
Implementa l'interfaccia BaseWrapper per fornire accesso unificato
@@ -46,31 +47,22 @@ class BinanceWrapper(BaseWrapper):
def get_product(self, asset_id: str) -> ProductInfo:
symbol = self.__format_symbol(asset_id)
ticker = self.client.get_symbol_ticker(symbol=symbol)
ticker_24h = self.client.get_ticker(symbol=symbol)
ticker['volume'] = ticker_24h.get('volume', 0) # Aggiunge volume 24h ai dati del ticker
ticker: dict[str, Any] = self.client.get_symbol_ticker(symbol=symbol) # type: ignore
ticker_24h: dict[str, Any] = self.client.get_ticker(symbol=symbol) # type: ignore
ticker['volume'] = ticker_24h.get('volume', 0)
return extract_product(self.currency, ticker)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
symbols = [self.__format_symbol(asset_id) for asset_id in asset_ids]
symbols_str = f"[\"{'","'.join(symbols)}\"]"
return [ self.get_product(asset_id) for asset_id in asset_ids ]
tickers = self.client.get_symbol_ticker(symbols=symbols_str)
tickers_24h = self.client.get_ticker(symbols=symbols_str) # un po brutale, ma va bene così
for t, t24 in zip(tickers, tickers_24h):
t['volume'] = t24.get('volume', 0)
return [extract_product(self.currency, ticker) for ticker in tickers]
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
symbol = self.__format_symbol(asset_id)
# Ottiene candele orarie degli ultimi 30 giorni
klines = self.client.get_historical_klines(
klines: list[list[Any]] = self.client.get_historical_klines( # type: ignore
symbol=symbol,
interval=Client.KLINE_INTERVAL_1HOUR,
limit=limit,
)
return [extract_price(kline) for kline in klines]

View File

@@ -1,9 +1,9 @@
import os
from enum import Enum
from datetime import datetime, timedelta
from coinbase.rest import RESTClient
from coinbase.rest.types.product_types import Candle, GetProductResponse, Product
from app.markets.base import ProductInfo, BaseWrapper, Price
from coinbase.rest import RESTClient # type: ignore
from coinbase.rest.types.product_types import Candle, GetProductResponse, Product # type: ignore
from app.markets.base import ProductInfo, MarketWrapper, Price
def extract_product(product_data: GetProductResponse | Product) -> ProductInfo:
@@ -37,7 +37,7 @@ class Granularity(Enum):
SIX_HOUR = 21600
ONE_DAY = 86400
class CoinBaseWrapper(BaseWrapper):
class CoinBaseWrapper(MarketWrapper):
"""
Wrapper per le API di Coinbase Advanced Trade.\n
Implementa l'interfaccia BaseWrapper per fornire accesso unificato
@@ -63,24 +63,26 @@ class CoinBaseWrapper(BaseWrapper):
def get_product(self, asset_id: str) -> ProductInfo:
asset_id = self.__format(asset_id)
asset = self.client.get_product(asset_id)
asset = self.client.get_product(asset_id) # type: ignore
return extract_product(asset)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
all_asset_ids = [self.__format(asset_id) for asset_id in asset_ids]
assets = self.client.get_products(product_ids=all_asset_ids)
assets = self.client.get_products(product_ids=all_asset_ids) # type: ignore
assert assets.products is not None, "No products data received from Coinbase"
return [extract_product(asset) for asset in assets.products]
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
asset_id = self.__format(asset_id)
end_time = datetime.now()
start_time = end_time - timedelta(days=14)
data = self.client.get_candles(
data = self.client.get_candles( # type: ignore
product_id=asset_id,
granularity=Granularity.ONE_HOUR.name,
start=str(int(start_time.timestamp())),
end=str(int(end_time.timestamp())),
limit=limit
)
assert data.candles is not None, "No candles data received from Coinbase"
return [extract_price(candle) for candle in data.candles]

View File

@@ -1,9 +1,10 @@
import os
from typing import Any
import requests
from app.markets.base import ProductInfo, BaseWrapper, Price
from app.markets.base import ProductInfo, MarketWrapper, Price
def extract_product(asset_data: dict) -> ProductInfo:
def extract_product(asset_data: dict[str, Any]) -> ProductInfo:
product = ProductInfo()
product.id = asset_data.get('FROMSYMBOL', '') + '-' + asset_data.get('TOSYMBOL', '')
product.symbol = asset_data.get('FROMSYMBOL', '')
@@ -12,7 +13,7 @@ def extract_product(asset_data: dict) -> ProductInfo:
assert product.price > 0, "Invalid price data received from CryptoCompare"
return product
def extract_price(price_data: dict) -> Price:
def extract_price(price_data: dict[str, Any]) -> Price:
price = Price()
price.high = float(price_data.get('high', 0))
price.low = float(price_data.get('low', 0))
@@ -26,7 +27,7 @@ def extract_price(price_data: dict) -> Price:
BASE_URL = "https://min-api.cryptocompare.com"
class CryptoCompareWrapper(BaseWrapper):
class CryptoCompareWrapper(MarketWrapper):
"""
Wrapper per le API pubbliche di CryptoCompare.
La documentazione delle API è disponibile qui: https://developers.coindesk.com/documentation/legacy/Price/SingleSymbolPriceEndpoint
@@ -39,7 +40,7 @@ class CryptoCompareWrapper(BaseWrapper):
self.api_key = api_key
self.currency = currency
def __request(self, endpoint: str, params: dict[str, str] | None = None) -> dict[str, str]:
def __request(self, endpoint: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
if params is None:
params = {}
params['api_key'] = self.api_key
@@ -60,7 +61,7 @@ class CryptoCompareWrapper(BaseWrapper):
"fsyms": ",".join(asset_ids),
"tsyms": self.currency
})
assets = []
assets: list[ProductInfo] = []
data = response.get('RAW', {})
for asset_id in asset_ids:
asset_data = data.get(asset_id, {}).get(self.currency, {})

View File

@@ -1,6 +1,6 @@
import json
from agno.tools.yfinance import YFinanceTools
from app.markets.base import BaseWrapper, ProductInfo, Price
from app.markets.base import MarketWrapper, ProductInfo, Price
def extract_product(stock_data: dict[str, str]) -> ProductInfo:
@@ -29,7 +29,7 @@ def extract_price(hist_data: dict[str, str]) -> Price:
return price
class YFinanceWrapper(BaseWrapper):
class YFinanceWrapper(MarketWrapper):
"""
Wrapper per YFinanceTools che fornisce dati di mercato per azioni, ETF e criptovalute.
Implementa l'interfaccia BaseWrapper per compatibilità con il sistema esistente.
@@ -55,13 +55,13 @@ class YFinanceWrapper(BaseWrapper):
return extract_product(stock_info)
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
products = []
products: list[ProductInfo] = []
for asset_id in asset_ids:
product = self.get_product(asset_id)
products.append(product)
return products
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
symbol = self._format_symbol(asset_id)
days = limit // 24 + 1 # Arrotonda per eccesso
@@ -71,7 +71,7 @@ class YFinanceWrapper(BaseWrapper):
# Il formato dei dati è {timestamp: {Open: x, High: y, Low: z, Close: w, Volume: v}}
timestamps = sorted(hist_data.keys())[-limit:]
prices = []
prices: list[Price] = []
for timestamp in timestamps:
temp = hist_data[timestamp]
temp['Timestamp'] = timestamp

View File

@@ -4,11 +4,12 @@ import traceback
from typing import Any, Callable, Generic, TypeVar
from agno.utils.log import log_info, log_warning #type: ignore
W = TypeVar("W")
T = TypeVar("T")
WrapperType = TypeVar("WrapperType")
WrapperClassType = TypeVar("WrapperClassType")
OutputType = TypeVar("OutputType")
class WrapperHandler(Generic[W]):
class WrapperHandler(Generic[WrapperType]):
"""
A handler for managing multiple wrappers with retry logic.
It attempts to call a function on the current wrapper, and if it fails,
@@ -18,7 +19,7 @@ class WrapperHandler(Generic[W]):
Note: use `build_wrappers` to create an instance of this class for better error handling.
"""
def __init__(self, wrappers: list[W], try_per_wrapper: int = 3, retry_delay: int = 2):
def __init__(self, wrappers: list[WrapperType], try_per_wrapper: int = 3, retry_delay: int = 2):
"""
Initializes the WrapperHandler with a list of wrappers and retry settings.\n
Use `build_wrappers` to create an instance of this class for better error handling.
@@ -35,7 +36,7 @@ class WrapperHandler(Generic[W]):
self.index = 0
self.retry_count = 0
def try_call(self, func: Callable[[W], T]) -> T:
def try_call(self, func: Callable[[WrapperType], OutputType]) -> OutputType:
"""
Attempts to call the provided function on the current wrapper.
If it fails, it retries a specified number of times before switching to the next wrapper.
@@ -76,7 +77,7 @@ class WrapperHandler(Generic[W]):
raise Exception(f"All wrappers failed, latest error: {error}")
def try_call_all(self, func: Callable[[W], T]) -> dict[str, T]:
def try_call_all(self, func: Callable[[WrapperType], OutputType]) -> dict[str, OutputType]:
"""
Calls the provided function on all wrappers, collecting results.
If a wrapper fails, it logs a warning and continues with the next.
@@ -90,7 +91,7 @@ class WrapperHandler(Generic[W]):
"""
log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}")
results: dict[str, T] = {}
results: dict[str, OutputType] = {}
error = ""
for wrapper in self.wrappers:
wrapper_name = wrapper.__class__.__name__
@@ -115,7 +116,7 @@ class WrapperHandler(Generic[W]):
return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]"
@staticmethod
def build_wrappers(constructors: list[type[W]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[W]':
def build_wrappers(constructors: list[type[WrapperClassType]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[WrapperClassType]':
"""
Builds a WrapperHandler instance with the given wrapper constructors.
It attempts to initialize each wrapper and logs a warning if any cannot be initialized.
@@ -132,7 +133,7 @@ class WrapperHandler(Generic[W]):
"""
assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}"
result: list[W] = []
result: list[WrapperClassType] = []
for wrapper_class in constructors:
try:
wrapper = wrapper_class(**(kwargs or {}))