Rinominato BaseWrapper in MarketWrapper e fix type check markets
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from agno.tools import Toolkit
|
from agno.tools import Toolkit
|
||||||
from app.markets.base import BaseWrapper, Price, ProductInfo
|
from app.markets.base import MarketWrapper, Price, ProductInfo
|
||||||
from app.markets.binance import BinanceWrapper
|
from app.markets.binance import BinanceWrapper
|
||||||
from app.markets.coinbase import CoinBaseWrapper
|
from app.markets.coinbase import CoinBaseWrapper
|
||||||
from app.markets.cryptocompare import CryptoCompareWrapper
|
from app.markets.cryptocompare import CryptoCompareWrapper
|
||||||
@@ -10,7 +10,7 @@ from app.utils.wrapper_handler import WrapperHandler
|
|||||||
__all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "ProductInfo", "Price" ]
|
__all__ = [ "MarketAPIsTool", "BinanceWrapper", "CoinBaseWrapper", "CryptoCompareWrapper", "YFinanceWrapper", "ProductInfo", "Price" ]
|
||||||
|
|
||||||
|
|
||||||
class MarketAPIsTool(BaseWrapper, Toolkit):
|
class MarketAPIsTool(MarketWrapper, Toolkit):
|
||||||
"""
|
"""
|
||||||
Class that aggregates multiple market API wrappers and manages them using WrapperHandler.
|
Class that aggregates multiple market API wrappers and manages them using WrapperHandler.
|
||||||
This class supports retrieving product information and historical prices.
|
This class supports retrieving product information and historical prices.
|
||||||
@@ -34,10 +34,10 @@ class MarketAPIsTool(BaseWrapper, Toolkit):
|
|||||||
currency (str): Valuta in cui restituire i prezzi. Default è "USD".
|
currency (str): Valuta in cui restituire i prezzi. Default è "USD".
|
||||||
"""
|
"""
|
||||||
kwargs = {"currency": currency or "USD"}
|
kwargs = {"currency": currency or "USD"}
|
||||||
wrappers = [ BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper ]
|
wrappers: list[type[MarketWrapper]] = [BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper]
|
||||||
self.wrappers: WrapperHandler[BaseWrapper] = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs)
|
self.wrappers = WrapperHandler.build_wrappers(wrappers, kwargs=kwargs)
|
||||||
|
|
||||||
Toolkit.__init__(
|
Toolkit.__init__( # type: ignore
|
||||||
self,
|
self,
|
||||||
name="Market APIs Toolkit",
|
name="Market APIs Toolkit",
|
||||||
tools=[
|
tools=[
|
||||||
@@ -53,7 +53,7 @@ class MarketAPIsTool(BaseWrapper, Toolkit):
|
|||||||
return self.wrappers.try_call(lambda w: w.get_product(asset_id))
|
return self.wrappers.try_call(lambda w: w.get_product(asset_id))
|
||||||
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
|
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
|
||||||
return self.wrappers.try_call(lambda w: w.get_products(asset_ids))
|
return self.wrappers.try_call(lambda w: w.get_products(asset_ids))
|
||||||
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
|
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
|
||||||
return self.wrappers.try_call(lambda w: w.get_historical_prices(asset_id, limit))
|
return self.wrappers.try_call(lambda w: w.get_historical_prices(asset_id, limit))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class Price(BaseModel):
|
|||||||
volume: float = 0.0
|
volume: float = 0.0
|
||||||
timestamp_ms: int = 0 # Timestamp in milliseconds
|
timestamp_ms: int = 0 # Timestamp in milliseconds
|
||||||
|
|
||||||
class BaseWrapper:
|
class MarketWrapper:
|
||||||
"""
|
"""
|
||||||
Base class for market API wrappers.
|
Base class for market API wrappers.
|
||||||
All market API wrappers should inherit from this class and implement the methods.
|
All market API wrappers should inherit from this class and implement the methods.
|
||||||
@@ -50,7 +50,7 @@ class BaseWrapper:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError("This method should be overridden by subclasses")
|
raise NotImplementedError("This method should be overridden by subclasses")
|
||||||
|
|
||||||
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
|
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
|
||||||
"""
|
"""
|
||||||
Get historical price data for a specific asset ID.
|
Get historical price data for a specific asset ID.
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
import os
|
import os
|
||||||
from binance.client import Client
|
from typing import Any
|
||||||
from app.markets.base import ProductInfo, BaseWrapper, Price
|
from binance.client import Client # type: ignore
|
||||||
|
from app.markets.base import ProductInfo, MarketWrapper, Price
|
||||||
|
|
||||||
|
|
||||||
def extract_product(currency: str, ticker_data: dict[str, str]) -> ProductInfo:
|
def extract_product(currency: str, ticker_data: dict[str, Any]) -> ProductInfo:
|
||||||
product = ProductInfo()
|
product = ProductInfo()
|
||||||
product.id = ticker_data.get('symbol')
|
product.id = ticker_data.get('symbol', '')
|
||||||
product.symbol = ticker_data.get('symbol', '').replace(currency, '')
|
product.symbol = ticker_data.get('symbol', '').replace(currency, '')
|
||||||
product.price = float(ticker_data.get('price', 0))
|
product.price = float(ticker_data.get('price', 0))
|
||||||
product.volume_24h = float(ticker_data.get('volume', 0))
|
product.volume_24h = float(ticker_data.get('volume', 0))
|
||||||
product.quote_currency = currency
|
product.quote_currency = currency
|
||||||
return product
|
return product
|
||||||
|
|
||||||
def extract_price(kline_data: list) -> Price:
|
def extract_price(kline_data: list[Any]) -> Price:
|
||||||
price = Price()
|
price = Price()
|
||||||
price.open = float(kline_data[1])
|
price.open = float(kline_data[1])
|
||||||
price.high = float(kline_data[2])
|
price.high = float(kline_data[2])
|
||||||
@@ -22,7 +23,7 @@ def extract_price(kline_data: list) -> Price:
|
|||||||
price.timestamp_ms = kline_data[0]
|
price.timestamp_ms = kline_data[0]
|
||||||
return price
|
return price
|
||||||
|
|
||||||
class BinanceWrapper(BaseWrapper):
|
class BinanceWrapper(MarketWrapper):
|
||||||
"""
|
"""
|
||||||
Wrapper per le API autenticate di Binance.\n
|
Wrapper per le API autenticate di Binance.\n
|
||||||
Implementa l'interfaccia BaseWrapper per fornire accesso unificato
|
Implementa l'interfaccia BaseWrapper per fornire accesso unificato
|
||||||
@@ -46,31 +47,22 @@ class BinanceWrapper(BaseWrapper):
|
|||||||
def get_product(self, asset_id: str) -> ProductInfo:
|
def get_product(self, asset_id: str) -> ProductInfo:
|
||||||
symbol = self.__format_symbol(asset_id)
|
symbol = self.__format_symbol(asset_id)
|
||||||
|
|
||||||
ticker = self.client.get_symbol_ticker(symbol=symbol)
|
ticker: dict[str, Any] = self.client.get_symbol_ticker(symbol=symbol) # type: ignore
|
||||||
ticker_24h = self.client.get_ticker(symbol=symbol)
|
ticker_24h: dict[str, Any] = self.client.get_ticker(symbol=symbol) # type: ignore
|
||||||
ticker['volume'] = ticker_24h.get('volume', 0) # Aggiunge volume 24h ai dati del ticker
|
ticker['volume'] = ticker_24h.get('volume', 0)
|
||||||
|
|
||||||
return extract_product(self.currency, ticker)
|
return extract_product(self.currency, ticker)
|
||||||
|
|
||||||
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
|
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
|
||||||
symbols = [self.__format_symbol(asset_id) for asset_id in asset_ids]
|
return [ self.get_product(asset_id) for asset_id in asset_ids ]
|
||||||
symbols_str = f"[\"{'","'.join(symbols)}\"]"
|
|
||||||
|
|
||||||
tickers = self.client.get_symbol_ticker(symbols=symbols_str)
|
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
|
||||||
tickers_24h = self.client.get_ticker(symbols=symbols_str) # un po brutale, ma va bene così
|
|
||||||
for t, t24 in zip(tickers, tickers_24h):
|
|
||||||
t['volume'] = t24.get('volume', 0)
|
|
||||||
|
|
||||||
return [extract_product(self.currency, ticker) for ticker in tickers]
|
|
||||||
|
|
||||||
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
|
|
||||||
symbol = self.__format_symbol(asset_id)
|
symbol = self.__format_symbol(asset_id)
|
||||||
|
|
||||||
# Ottiene candele orarie degli ultimi 30 giorni
|
# Ottiene candele orarie degli ultimi 30 giorni
|
||||||
klines = self.client.get_historical_klines(
|
klines: list[list[Any]] = self.client.get_historical_klines( # type: ignore
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
interval=Client.KLINE_INTERVAL_1HOUR,
|
interval=Client.KLINE_INTERVAL_1HOUR,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
return [extract_price(kline) for kline in klines]
|
return [extract_price(kline) for kline in klines]
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from coinbase.rest import RESTClient
|
from coinbase.rest import RESTClient # type: ignore
|
||||||
from coinbase.rest.types.product_types import Candle, GetProductResponse, Product
|
from coinbase.rest.types.product_types import Candle, GetProductResponse, Product # type: ignore
|
||||||
from app.markets.base import ProductInfo, BaseWrapper, Price
|
from app.markets.base import ProductInfo, MarketWrapper, Price
|
||||||
|
|
||||||
|
|
||||||
def extract_product(product_data: GetProductResponse | Product) -> ProductInfo:
|
def extract_product(product_data: GetProductResponse | Product) -> ProductInfo:
|
||||||
@@ -37,7 +37,7 @@ class Granularity(Enum):
|
|||||||
SIX_HOUR = 21600
|
SIX_HOUR = 21600
|
||||||
ONE_DAY = 86400
|
ONE_DAY = 86400
|
||||||
|
|
||||||
class CoinBaseWrapper(BaseWrapper):
|
class CoinBaseWrapper(MarketWrapper):
|
||||||
"""
|
"""
|
||||||
Wrapper per le API di Coinbase Advanced Trade.\n
|
Wrapper per le API di Coinbase Advanced Trade.\n
|
||||||
Implementa l'interfaccia BaseWrapper per fornire accesso unificato
|
Implementa l'interfaccia BaseWrapper per fornire accesso unificato
|
||||||
@@ -63,24 +63,26 @@ class CoinBaseWrapper(BaseWrapper):
|
|||||||
|
|
||||||
def get_product(self, asset_id: str) -> ProductInfo:
|
def get_product(self, asset_id: str) -> ProductInfo:
|
||||||
asset_id = self.__format(asset_id)
|
asset_id = self.__format(asset_id)
|
||||||
asset = self.client.get_product(asset_id)
|
asset = self.client.get_product(asset_id) # type: ignore
|
||||||
return extract_product(asset)
|
return extract_product(asset)
|
||||||
|
|
||||||
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
|
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
|
||||||
all_asset_ids = [self.__format(asset_id) for asset_id in asset_ids]
|
all_asset_ids = [self.__format(asset_id) for asset_id in asset_ids]
|
||||||
assets = self.client.get_products(product_ids=all_asset_ids)
|
assets = self.client.get_products(product_ids=all_asset_ids) # type: ignore
|
||||||
|
assert assets.products is not None, "No products data received from Coinbase"
|
||||||
return [extract_product(asset) for asset in assets.products]
|
return [extract_product(asset) for asset in assets.products]
|
||||||
|
|
||||||
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
|
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
|
||||||
asset_id = self.__format(asset_id)
|
asset_id = self.__format(asset_id)
|
||||||
end_time = datetime.now()
|
end_time = datetime.now()
|
||||||
start_time = end_time - timedelta(days=14)
|
start_time = end_time - timedelta(days=14)
|
||||||
|
|
||||||
data = self.client.get_candles(
|
data = self.client.get_candles( # type: ignore
|
||||||
product_id=asset_id,
|
product_id=asset_id,
|
||||||
granularity=Granularity.ONE_HOUR.name,
|
granularity=Granularity.ONE_HOUR.name,
|
||||||
start=str(int(start_time.timestamp())),
|
start=str(int(start_time.timestamp())),
|
||||||
end=str(int(end_time.timestamp())),
|
end=str(int(end_time.timestamp())),
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
|
assert data.candles is not None, "No candles data received from Coinbase"
|
||||||
return [extract_price(candle) for candle in data.candles]
|
return [extract_price(candle) for candle in data.candles]
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import Any
|
||||||
import requests
|
import requests
|
||||||
from app.markets.base import ProductInfo, BaseWrapper, Price
|
from app.markets.base import ProductInfo, MarketWrapper, Price
|
||||||
|
|
||||||
|
|
||||||
def extract_product(asset_data: dict) -> ProductInfo:
|
def extract_product(asset_data: dict[str, Any]) -> ProductInfo:
|
||||||
product = ProductInfo()
|
product = ProductInfo()
|
||||||
product.id = asset_data.get('FROMSYMBOL', '') + '-' + asset_data.get('TOSYMBOL', '')
|
product.id = asset_data.get('FROMSYMBOL', '') + '-' + asset_data.get('TOSYMBOL', '')
|
||||||
product.symbol = asset_data.get('FROMSYMBOL', '')
|
product.symbol = asset_data.get('FROMSYMBOL', '')
|
||||||
@@ -12,7 +13,7 @@ def extract_product(asset_data: dict) -> ProductInfo:
|
|||||||
assert product.price > 0, "Invalid price data received from CryptoCompare"
|
assert product.price > 0, "Invalid price data received from CryptoCompare"
|
||||||
return product
|
return product
|
||||||
|
|
||||||
def extract_price(price_data: dict) -> Price:
|
def extract_price(price_data: dict[str, Any]) -> Price:
|
||||||
price = Price()
|
price = Price()
|
||||||
price.high = float(price_data.get('high', 0))
|
price.high = float(price_data.get('high', 0))
|
||||||
price.low = float(price_data.get('low', 0))
|
price.low = float(price_data.get('low', 0))
|
||||||
@@ -26,7 +27,7 @@ def extract_price(price_data: dict) -> Price:
|
|||||||
|
|
||||||
BASE_URL = "https://min-api.cryptocompare.com"
|
BASE_URL = "https://min-api.cryptocompare.com"
|
||||||
|
|
||||||
class CryptoCompareWrapper(BaseWrapper):
|
class CryptoCompareWrapper(MarketWrapper):
|
||||||
"""
|
"""
|
||||||
Wrapper per le API pubbliche di CryptoCompare.
|
Wrapper per le API pubbliche di CryptoCompare.
|
||||||
La documentazione delle API è disponibile qui: https://developers.coindesk.com/documentation/legacy/Price/SingleSymbolPriceEndpoint
|
La documentazione delle API è disponibile qui: https://developers.coindesk.com/documentation/legacy/Price/SingleSymbolPriceEndpoint
|
||||||
@@ -39,7 +40,7 @@ class CryptoCompareWrapper(BaseWrapper):
|
|||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.currency = currency
|
self.currency = currency
|
||||||
|
|
||||||
def __request(self, endpoint: str, params: dict[str, str] | None = None) -> dict[str, str]:
|
def __request(self, endpoint: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
params['api_key'] = self.api_key
|
params['api_key'] = self.api_key
|
||||||
@@ -60,7 +61,7 @@ class CryptoCompareWrapper(BaseWrapper):
|
|||||||
"fsyms": ",".join(asset_ids),
|
"fsyms": ",".join(asset_ids),
|
||||||
"tsyms": self.currency
|
"tsyms": self.currency
|
||||||
})
|
})
|
||||||
assets = []
|
assets: list[ProductInfo] = []
|
||||||
data = response.get('RAW', {})
|
data = response.get('RAW', {})
|
||||||
for asset_id in asset_ids:
|
for asset_id in asset_ids:
|
||||||
asset_data = data.get(asset_id, {}).get(self.currency, {})
|
asset_data = data.get(asset_id, {}).get(self.currency, {})
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from agno.tools.yfinance import YFinanceTools
|
from agno.tools.yfinance import YFinanceTools
|
||||||
from app.markets.base import BaseWrapper, ProductInfo, Price
|
from app.markets.base import MarketWrapper, ProductInfo, Price
|
||||||
|
|
||||||
|
|
||||||
def extract_product(stock_data: dict[str, str]) -> ProductInfo:
|
def extract_product(stock_data: dict[str, str]) -> ProductInfo:
|
||||||
@@ -29,7 +29,7 @@ def extract_price(hist_data: dict[str, str]) -> Price:
|
|||||||
return price
|
return price
|
||||||
|
|
||||||
|
|
||||||
class YFinanceWrapper(BaseWrapper):
|
class YFinanceWrapper(MarketWrapper):
|
||||||
"""
|
"""
|
||||||
Wrapper per YFinanceTools che fornisce dati di mercato per azioni, ETF e criptovalute.
|
Wrapper per YFinanceTools che fornisce dati di mercato per azioni, ETF e criptovalute.
|
||||||
Implementa l'interfaccia BaseWrapper per compatibilità con il sistema esistente.
|
Implementa l'interfaccia BaseWrapper per compatibilità con il sistema esistente.
|
||||||
@@ -55,13 +55,13 @@ class YFinanceWrapper(BaseWrapper):
|
|||||||
return extract_product(stock_info)
|
return extract_product(stock_info)
|
||||||
|
|
||||||
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
|
def get_products(self, asset_ids: list[str]) -> list[ProductInfo]:
|
||||||
products = []
|
products: list[ProductInfo] = []
|
||||||
for asset_id in asset_ids:
|
for asset_id in asset_ids:
|
||||||
product = self.get_product(asset_id)
|
product = self.get_product(asset_id)
|
||||||
products.append(product)
|
products.append(product)
|
||||||
return products
|
return products
|
||||||
|
|
||||||
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> list[Price]:
|
def get_historical_prices(self, asset_id: str, limit: int = 100) -> list[Price]:
|
||||||
symbol = self._format_symbol(asset_id)
|
symbol = self._format_symbol(asset_id)
|
||||||
|
|
||||||
days = limit // 24 + 1 # Arrotonda per eccesso
|
days = limit // 24 + 1 # Arrotonda per eccesso
|
||||||
@@ -71,7 +71,7 @@ class YFinanceWrapper(BaseWrapper):
|
|||||||
# Il formato dei dati è {timestamp: {Open: x, High: y, Low: z, Close: w, Volume: v}}
|
# Il formato dei dati è {timestamp: {Open: x, High: y, Low: z, Close: w, Volume: v}}
|
||||||
timestamps = sorted(hist_data.keys())[-limit:]
|
timestamps = sorted(hist_data.keys())[-limit:]
|
||||||
|
|
||||||
prices = []
|
prices: list[Price] = []
|
||||||
for timestamp in timestamps:
|
for timestamp in timestamps:
|
||||||
temp = hist_data[timestamp]
|
temp = hist_data[timestamp]
|
||||||
temp['Timestamp'] = timestamp
|
temp['Timestamp'] = timestamp
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ import traceback
|
|||||||
from typing import Any, Callable, Generic, TypeVar
|
from typing import Any, Callable, Generic, TypeVar
|
||||||
from agno.utils.log import log_info, log_warning #type: ignore
|
from agno.utils.log import log_info, log_warning #type: ignore
|
||||||
|
|
||||||
W = TypeVar("W")
|
WrapperType = TypeVar("WrapperType")
|
||||||
T = TypeVar("T")
|
WrapperClassType = TypeVar("WrapperClassType")
|
||||||
|
OutputType = TypeVar("OutputType")
|
||||||
|
|
||||||
|
|
||||||
class WrapperHandler(Generic[W]):
|
class WrapperHandler(Generic[WrapperType]):
|
||||||
"""
|
"""
|
||||||
A handler for managing multiple wrappers with retry logic.
|
A handler for managing multiple wrappers with retry logic.
|
||||||
It attempts to call a function on the current wrapper, and if it fails,
|
It attempts to call a function on the current wrapper, and if it fails,
|
||||||
@@ -18,7 +19,7 @@ class WrapperHandler(Generic[W]):
|
|||||||
Note: use `build_wrappers` to create an instance of this class for better error handling.
|
Note: use `build_wrappers` to create an instance of this class for better error handling.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, wrappers: list[W], try_per_wrapper: int = 3, retry_delay: int = 2):
|
def __init__(self, wrappers: list[WrapperType], try_per_wrapper: int = 3, retry_delay: int = 2):
|
||||||
"""
|
"""
|
||||||
Initializes the WrapperHandler with a list of wrappers and retry settings.\n
|
Initializes the WrapperHandler with a list of wrappers and retry settings.\n
|
||||||
Use `build_wrappers` to create an instance of this class for better error handling.
|
Use `build_wrappers` to create an instance of this class for better error handling.
|
||||||
@@ -35,7 +36,7 @@ class WrapperHandler(Generic[W]):
|
|||||||
self.index = 0
|
self.index = 0
|
||||||
self.retry_count = 0
|
self.retry_count = 0
|
||||||
|
|
||||||
def try_call(self, func: Callable[[W], T]) -> T:
|
def try_call(self, func: Callable[[WrapperType], OutputType]) -> OutputType:
|
||||||
"""
|
"""
|
||||||
Attempts to call the provided function on the current wrapper.
|
Attempts to call the provided function on the current wrapper.
|
||||||
If it fails, it retries a specified number of times before switching to the next wrapper.
|
If it fails, it retries a specified number of times before switching to the next wrapper.
|
||||||
@@ -76,7 +77,7 @@ class WrapperHandler(Generic[W]):
|
|||||||
|
|
||||||
raise Exception(f"All wrappers failed, latest error: {error}")
|
raise Exception(f"All wrappers failed, latest error: {error}")
|
||||||
|
|
||||||
def try_call_all(self, func: Callable[[W], T]) -> dict[str, T]:
|
def try_call_all(self, func: Callable[[WrapperType], OutputType]) -> dict[str, OutputType]:
|
||||||
"""
|
"""
|
||||||
Calls the provided function on all wrappers, collecting results.
|
Calls the provided function on all wrappers, collecting results.
|
||||||
If a wrapper fails, it logs a warning and continues with the next.
|
If a wrapper fails, it logs a warning and continues with the next.
|
||||||
@@ -90,7 +91,7 @@ class WrapperHandler(Generic[W]):
|
|||||||
"""
|
"""
|
||||||
log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}")
|
log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}")
|
||||||
|
|
||||||
results: dict[str, T] = {}
|
results: dict[str, OutputType] = {}
|
||||||
error = ""
|
error = ""
|
||||||
for wrapper in self.wrappers:
|
for wrapper in self.wrappers:
|
||||||
wrapper_name = wrapper.__class__.__name__
|
wrapper_name = wrapper.__class__.__name__
|
||||||
@@ -115,7 +116,7 @@ class WrapperHandler(Generic[W]):
|
|||||||
return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]"
|
return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_wrappers(constructors: list[type[W]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[W]':
|
def build_wrappers(constructors: list[type[WrapperClassType]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[WrapperClassType]':
|
||||||
"""
|
"""
|
||||||
Builds a WrapperHandler instance with the given wrapper constructors.
|
Builds a WrapperHandler instance with the given wrapper constructors.
|
||||||
It attempts to initialize each wrapper and logs a warning if any cannot be initialized.
|
It attempts to initialize each wrapper and logs a warning if any cannot be initialized.
|
||||||
@@ -132,7 +133,7 @@ class WrapperHandler(Generic[W]):
|
|||||||
"""
|
"""
|
||||||
assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}"
|
assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}"
|
||||||
|
|
||||||
result: list[W] = []
|
result: list[WrapperClassType] = []
|
||||||
for wrapper_class in constructors:
|
for wrapper_class in constructors:
|
||||||
try:
|
try:
|
||||||
wrapper = wrapper_class(**(kwargs or {}))
|
wrapper = wrapper_class(**(kwargs or {}))
|
||||||
|
|||||||
Reference in New Issue
Block a user