12 fix docs (#13)
* fix dependencies uv.lock * refactor test markers for clarity * refactor: clean up imports and remove unused files * refactor: remove unused agent files and clean up market API instructions * refactor: enhance wrapper initialization with keyword arguments and clean up tests * refactor: remove PublicBinanceAgent * refactor: aggregator - simplified MarketDataAggregator and related models to functions * refactor: update README and .env.example to reflect the latest changes to the project * refactor: simplify product info and price creation in YFinanceWrapper * refactor: remove get_all_products method from market API wrappers and update documentation * fix: environment variable assertions * refactor: remove status attribute from ProductInfo and update related methods to use timestamp_ms * feat: implement aggregate_history_prices function to calculate hourly price averages * refactor: update docker-compose and app.py for improved environment variable handling and compatibility * feat: add detailed market instructions and improve error handling in price aggregation methods * feat: add aggregated news retrieval methods for top headlines and latest news * refactor: improve error messages in WrapperHandler for better clarity * fix: correct quote currency extraction in create_product_info and remove debug prints from tests
This commit was merged in pull request #13.
This commit is contained in:
committed by
GitHub
parent
a8755913d8
commit
d2fbc0ceea
@@ -1,186 +0,0 @@
|
||||
import statistics
|
||||
from typing import Dict, List, Optional, Set
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from app.markets.base import ProductInfo
|
||||
|
||||
class AggregationMetadata(BaseModel):
|
||||
"""Metadati nascosti per debugging e audit trail"""
|
||||
sources_used: Set[str] = Field(default_factory=set, description="Exchange usati nell'aggregazione")
|
||||
sources_ignored: Set[str] = Field(default_factory=set, description="Exchange ignorati (errori)")
|
||||
aggregation_timestamp: str = Field(default="", description="Timestamp dell'aggregazione")
|
||||
confidence_score: float = Field(default=0.0, description="Score 0-1 sulla qualità dei dati")
|
||||
|
||||
class Config:
|
||||
# Nasconde questi campi dalla serializzazione di default
|
||||
extra = "forbid"
|
||||
|
||||
class AggregatedProductInfo(ProductInfo):
|
||||
"""
|
||||
Versione aggregata di ProductInfo che mantiene la trasparenza per l'utente finale
|
||||
mentre fornisce metadati di debugging opzionali.
|
||||
"""
|
||||
|
||||
# Override dei campi con logica di aggregazione
|
||||
id: str = Field(description="ID aggregato basato sul simbolo standardizzato")
|
||||
status: str = Field(description="Status aggregato (majority vote o conservative)")
|
||||
|
||||
# Campi privati per debugging (non visibili di default)
|
||||
_metadata: Optional[AggregationMetadata] = PrivateAttr(default=None)
|
||||
_source_data: Optional[Dict[str, ProductInfo]] = PrivateAttr(default=None)
|
||||
|
||||
@classmethod
|
||||
def from_multiple_sources(cls, products: List[ProductInfo]) -> 'AggregatedProductInfo':
|
||||
"""
|
||||
Crea un AggregatedProductInfo da una lista di ProductInfo.
|
||||
Usa strategie intelligenti per gestire ID e status.
|
||||
"""
|
||||
if not products:
|
||||
raise ValueError("Nessun prodotto da aggregare")
|
||||
|
||||
# Raggruppa per symbol (la chiave vera per l'aggregazione)
|
||||
symbol_groups = {}
|
||||
for product in products:
|
||||
if product.symbol not in symbol_groups:
|
||||
symbol_groups[product.symbol] = []
|
||||
symbol_groups[product.symbol].append(product)
|
||||
|
||||
# Per ora gestiamo un symbol alla volta
|
||||
if len(symbol_groups) > 1:
|
||||
raise ValueError(f"Simboli multipli non supportati: {list(symbol_groups.keys())}")
|
||||
|
||||
symbol_products = list(symbol_groups.values())[0]
|
||||
|
||||
# Estrai tutte le fonti
|
||||
sources = []
|
||||
for product in symbol_products:
|
||||
# Determina la fonte dall'ID o da altri metadati se disponibili
|
||||
source = cls._detect_source(product)
|
||||
sources.append(source)
|
||||
|
||||
# Aggrega i dati
|
||||
aggregated_data = cls._aggregate_products(symbol_products, sources)
|
||||
|
||||
# Crea l'istanza e assegna gli attributi privati
|
||||
instance = cls(**aggregated_data)
|
||||
instance._metadata = aggregated_data.get("_metadata")
|
||||
instance._source_data = aggregated_data.get("_source_data")
|
||||
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def _detect_source(product: ProductInfo) -> str:
|
||||
"""Rileva la fonte da un ProductInfo"""
|
||||
# Strategia semplice: usa pattern negli ID
|
||||
if "coinbase" in product.id.lower() or "cb" in product.id.lower():
|
||||
return "coinbase"
|
||||
elif "binance" in product.id.lower() or "bn" in product.id.lower():
|
||||
return "binance"
|
||||
elif "crypto" in product.id.lower() or "cc" in product.id.lower():
|
||||
return "cryptocompare"
|
||||
elif "yfinance" in product.id.lower() or "yf" in product.id.lower():
|
||||
return "yfinance"
|
||||
else:
|
||||
return "unknown"
|
||||
|
||||
@classmethod
|
||||
def _aggregate_products(cls, products: List[ProductInfo], sources: List[str]) -> dict:
|
||||
"""
|
||||
Logica di aggregazione principale.
|
||||
Gestisce ID, status e altri campi numerici.
|
||||
"""
|
||||
import statistics
|
||||
from datetime import datetime
|
||||
|
||||
# ID: usa il symbol come chiave standardizzata
|
||||
symbol = products[0].symbol
|
||||
aggregated_id = f"{symbol}_AGG"
|
||||
|
||||
# Status: strategia "conservativa" - il più restrittivo vince
|
||||
# Ordine: trading_only < limit_only < auction < maintenance < offline
|
||||
status_priority = {
|
||||
"trading": 1,
|
||||
"limit_only": 2,
|
||||
"auction": 3,
|
||||
"maintenance": 4,
|
||||
"offline": 5,
|
||||
"": 0 # Default se non specificato
|
||||
}
|
||||
|
||||
statuses = [p.status for p in products if p.status]
|
||||
if statuses:
|
||||
# Prendi lo status con priorità più alta (più restrittivo)
|
||||
aggregated_status = max(statuses, key=lambda s: status_priority.get(s, 0))
|
||||
else:
|
||||
aggregated_status = "trading" # Default ottimistico
|
||||
|
||||
# Prezzo: media semplice (uso diretto del campo price come float)
|
||||
prices = [p.price for p in products if p.price > 0]
|
||||
aggregated_price = statistics.mean(prices) if prices else 0.0
|
||||
|
||||
# Volume: somma (assumendo che i volumi siano esclusivi per exchange)
|
||||
volumes = [p.volume_24h for p in products if p.volume_24h > 0]
|
||||
total_volume = sum(volumes)
|
||||
aggregated_volume = sum(price_i * volume_i for price_i, volume_i in zip((p.price for p in products), (volume for volume in volumes))) / total_volume
|
||||
aggregated_volume = round(aggregated_volume, 5)
|
||||
# aggregated_volume = sum(volumes) if volumes else 0.0 # NOTE old implementation
|
||||
|
||||
# Valuta: prendi la prima (dovrebbero essere tutte uguali)
|
||||
quote_currency = next((p.quote_currency for p in products if p.quote_currency), "USD")
|
||||
|
||||
# Calcola confidence score
|
||||
confidence = cls._calculate_confidence(products, sources)
|
||||
|
||||
# Crea metadati per debugging
|
||||
metadata = AggregationMetadata(
|
||||
sources_used=set(sources),
|
||||
aggregation_timestamp=datetime.now().isoformat(),
|
||||
confidence_score=confidence
|
||||
)
|
||||
|
||||
# Salva dati sorgente per debugging
|
||||
source_data = dict(zip(sources, products))
|
||||
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"price": aggregated_price,
|
||||
"volume_24h": aggregated_volume,
|
||||
"quote_currency": quote_currency,
|
||||
"id": aggregated_id,
|
||||
"status": aggregated_status,
|
||||
"_metadata": metadata,
|
||||
"_source_data": source_data
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _calculate_confidence(products: List[ProductInfo], sources: List[str]) -> float:
|
||||
"""Calcola un punteggio di confidenza 0-1"""
|
||||
if not products:
|
||||
return 0.0
|
||||
|
||||
score = 1.0
|
||||
|
||||
# Riduci score se pochi dati
|
||||
if len(products) < 2:
|
||||
score *= 0.7
|
||||
|
||||
# Riduci score se prezzi troppo diversi
|
||||
prices = [p.price for p in products if p.price > 0]
|
||||
if len(prices) > 1:
|
||||
price_std = (max(prices) - min(prices)) / statistics.mean(prices)
|
||||
if price_std > 0.05: # >5% variazione
|
||||
score *= 0.8
|
||||
|
||||
# Riduci score se fonti sconosciute
|
||||
unknown_sources = sum(1 for s in sources if s == "unknown")
|
||||
if unknown_sources > 0:
|
||||
score *= (1 - unknown_sources / len(sources))
|
||||
|
||||
return max(0.0, min(1.0, score))
|
||||
|
||||
def get_debug_info(self) -> dict:
|
||||
"""Metodo opzionale per ottenere informazioni di debug"""
|
||||
return {
|
||||
"aggregated_product": self.dict(),
|
||||
"metadata": self._metadata.dict() if self._metadata else None,
|
||||
"sources": list(self._source_data.keys()) if self._source_data else []
|
||||
}
|
||||
91
src/app/utils/market_aggregation.py
Normal file
91
src/app/utils/market_aggregation.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import statistics
|
||||
from app.markets.base import ProductInfo, Price
|
||||
|
||||
|
||||
def aggregate_history_prices(prices: dict[str, list[Price]]) -> list[Price]:
|
||||
"""
|
||||
Aggrega i prezzi storici per symbol calcolando la media oraria.
|
||||
Args:
|
||||
prices (dict[str, list[Price]]): Mappa provider -> lista di Price
|
||||
Returns:
|
||||
list[Price]: Lista di Price aggregati per ora
|
||||
"""
|
||||
|
||||
# Costruiamo una mappa timestamp_h -> lista di Price
|
||||
timestamped_prices: dict[int, list[Price]] = {}
|
||||
for _, price_list in prices.items():
|
||||
for price in price_list:
|
||||
time = price.timestamp_ms - (price.timestamp_ms % 3600000) # arrotonda all'ora (non dovrebbe essere necessario)
|
||||
timestamped_prices.setdefault(time, []).append(price)
|
||||
|
||||
# Ora aggregiamo i prezzi per ogni ora
|
||||
aggregated_prices = []
|
||||
for time, price_list in timestamped_prices.items():
|
||||
price = Price()
|
||||
price.timestamp_ms = time
|
||||
price.high = statistics.mean([p.high for p in price_list])
|
||||
price.low = statistics.mean([p.low for p in price_list])
|
||||
price.open = statistics.mean([p.open for p in price_list])
|
||||
price.close = statistics.mean([p.close for p in price_list])
|
||||
price.volume = statistics.mean([p.volume for p in price_list])
|
||||
aggregated_prices.append(price)
|
||||
return aggregated_prices
|
||||
|
||||
def aggregate_product_info(products: dict[str, list[ProductInfo]]) -> list[ProductInfo]:
|
||||
"""
|
||||
Aggrega una lista di ProductInfo per symbol.
|
||||
Args:
|
||||
products (dict[str, list[ProductInfo]]): Mappa provider -> lista di ProductInfo
|
||||
Returns:
|
||||
list[ProductInfo]: Lista di ProductInfo aggregati per symbol
|
||||
"""
|
||||
|
||||
# Costruzione mappa symbol -> lista di ProductInfo
|
||||
symbols_infos: dict[str, list[ProductInfo]] = {}
|
||||
for _, product_list in products.items():
|
||||
for product in product_list:
|
||||
symbols_infos.setdefault(product.symbol, []).append(product)
|
||||
|
||||
# Aggregazione per ogni symbol
|
||||
sources = list(products.keys())
|
||||
aggregated_products = []
|
||||
for symbol, product_list in symbols_infos.items():
|
||||
product = ProductInfo()
|
||||
|
||||
product.id = f"{symbol}_AGGREGATED"
|
||||
product.symbol = symbol
|
||||
product.quote_currency = next(p.quote_currency for p in product_list if p.quote_currency)
|
||||
|
||||
volume_sum = sum(p.volume_24h for p in product_list)
|
||||
product.volume_24h = volume_sum / len(product_list) if product_list else 0.0
|
||||
|
||||
prices = sum(p.price * p.volume_24h for p in product_list)
|
||||
product.price = (prices / volume_sum) if volume_sum > 0 else 0.0
|
||||
|
||||
aggregated_products.append(product)
|
||||
return aggregated_products
|
||||
|
||||
def _calculate_confidence(products: list[ProductInfo], sources: list[str]) -> float:
|
||||
"""Calcola un punteggio di confidenza 0-1"""
|
||||
if not products:
|
||||
return 0.0
|
||||
|
||||
score = 1.0
|
||||
|
||||
# Riduci score se pochi dati
|
||||
if len(products) < 2:
|
||||
score *= 0.7
|
||||
|
||||
# Riduci score se prezzi troppo diversi
|
||||
prices = [p.price for p in products if p.price > 0]
|
||||
if len(prices) > 1:
|
||||
price_std = (max(prices) - min(prices)) / statistics.mean(prices)
|
||||
if price_std > 0.05: # >5% variazione
|
||||
score *= 0.8
|
||||
|
||||
# Riduci score se fonti sconosciute
|
||||
unknown_sources = sum(1 for s in sources if s == "unknown")
|
||||
if unknown_sources > 0:
|
||||
score *= (1 - unknown_sources / len(sources))
|
||||
|
||||
return max(0.0, min(1.0, score))
|
||||
@@ -1,184 +0,0 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
from app.markets.base import ProductInfo, Price
|
||||
from app.utils.aggregated_models import AggregatedProductInfo
|
||||
|
||||
class MarketDataAggregator:
|
||||
"""
|
||||
Aggregatore di dati di mercato che mantiene la trasparenza per l'utente.
|
||||
|
||||
Compone MarketAPIs per fornire gli stessi metodi, ma restituisce dati aggregati
|
||||
da tutte le fonti disponibili. L'utente finale non vede la complessità.
|
||||
"""
|
||||
|
||||
def __init__(self, currency: str = "USD"):
|
||||
# Import lazy per evitare circular import
|
||||
from app.markets import MarketAPIsTool
|
||||
self._market_apis = MarketAPIsTool(currency)
|
||||
self._aggregation_enabled = True
|
||||
|
||||
def get_product(self, asset_id: str) -> ProductInfo:
|
||||
"""
|
||||
Override che aggrega dati da tutte le fonti disponibili.
|
||||
Per l'utente sembra un normale ProductInfo.
|
||||
"""
|
||||
if not self._aggregation_enabled:
|
||||
return self._market_apis.get_product(asset_id)
|
||||
|
||||
# Raccogli dati da tutte le fonti
|
||||
try:
|
||||
raw_results = self.wrappers.try_call_all(
|
||||
lambda wrapper: wrapper.get_product(asset_id)
|
||||
)
|
||||
|
||||
# Converti in ProductInfo se necessario
|
||||
products = []
|
||||
for wrapper_class, result in raw_results.items():
|
||||
if isinstance(result, ProductInfo):
|
||||
products.append(result)
|
||||
elif isinstance(result, dict):
|
||||
# Converti dizionario in ProductInfo
|
||||
products.append(ProductInfo(**result))
|
||||
|
||||
if not products:
|
||||
raise Exception("Nessun dato disponibile")
|
||||
|
||||
# Aggrega i risultati
|
||||
aggregated = AggregatedProductInfo.from_multiple_sources(products)
|
||||
|
||||
# Restituisci come ProductInfo normale (nascondi la complessità)
|
||||
return ProductInfo(**aggregated.dict(exclude={"_metadata", "_source_data"}))
|
||||
|
||||
except Exception as e:
|
||||
# Fallback: usa il comportamento normale se l'aggregazione fallisce
|
||||
return self._market_apis.get_product(asset_id)
|
||||
|
||||
def get_products(self, asset_ids: List[str]) -> List[ProductInfo]:
|
||||
"""
|
||||
Aggrega dati per multiple asset.
|
||||
"""
|
||||
if not self._aggregation_enabled:
|
||||
return self._market_apis.get_products(asset_ids)
|
||||
|
||||
aggregated_products = []
|
||||
|
||||
for asset_id in asset_ids:
|
||||
try:
|
||||
product = self.get_product(asset_id)
|
||||
aggregated_products.append(product)
|
||||
except Exception as e:
|
||||
# Salta asset che non riescono ad aggregare
|
||||
continue
|
||||
|
||||
return aggregated_products
|
||||
|
||||
def get_all_products(self) -> List[ProductInfo]:
|
||||
"""
|
||||
Aggrega tutti i prodotti disponibili.
|
||||
"""
|
||||
if not self._aggregation_enabled:
|
||||
return self._market_apis.get_all_products()
|
||||
|
||||
# Raccogli tutti i prodotti da tutte le fonti
|
||||
try:
|
||||
all_products_by_source = self.wrappers.try_call_all(
|
||||
lambda wrapper: wrapper.get_all_products()
|
||||
)
|
||||
|
||||
# Raggruppa per symbol per aggregare
|
||||
symbol_groups = {}
|
||||
for wrapper_class, products in all_products_by_source.items():
|
||||
if not isinstance(products, list):
|
||||
continue
|
||||
|
||||
for product in products:
|
||||
if isinstance(product, dict):
|
||||
product = ProductInfo(**product)
|
||||
|
||||
if product.symbol not in symbol_groups:
|
||||
symbol_groups[product.symbol] = []
|
||||
symbol_groups[product.symbol].append(product)
|
||||
|
||||
# Aggrega ogni gruppo
|
||||
aggregated_products = []
|
||||
for symbol, products in symbol_groups.items():
|
||||
try:
|
||||
aggregated = AggregatedProductInfo.from_multiple_sources(products)
|
||||
# Restituisci come ProductInfo normale
|
||||
aggregated_products.append(
|
||||
ProductInfo(**aggregated.dict(exclude={"_metadata", "_source_data"}))
|
||||
)
|
||||
except Exception:
|
||||
# Se l'aggregazione fallisce, usa il primo disponibile
|
||||
if products:
|
||||
aggregated_products.append(products[0])
|
||||
|
||||
return aggregated_products
|
||||
|
||||
except Exception as e:
|
||||
# Fallback: usa il comportamento normale
|
||||
return self._market_apis.get_all_products()
|
||||
|
||||
def get_historical_prices(self, asset_id: str = "BTC", limit: int = 100) -> List[Price]:
|
||||
"""
|
||||
Per i dati storici, usa una strategia diversa:
|
||||
prendi i dati dalla fonte più affidabile o aggrega se possibile.
|
||||
"""
|
||||
if not self._aggregation_enabled:
|
||||
return self._market_apis.get_historical_prices(asset_id, limit)
|
||||
|
||||
# Per dati storici, usa il primo wrapper che funziona
|
||||
# (l'aggregazione di dati storici è più complessa)
|
||||
try:
|
||||
return self.wrappers.try_call(
|
||||
lambda wrapper: wrapper.get_historical_prices(asset_id, limit)
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback: usa il comportamento normale
|
||||
return self._market_apis.get_historical_prices(asset_id, limit)
|
||||
|
||||
def enable_aggregation(self, enabled: bool = True):
|
||||
"""Abilita o disabilita l'aggregazione"""
|
||||
self._aggregation_enabled = enabled
|
||||
|
||||
def is_aggregation_enabled(self) -> bool:
|
||||
"""Controlla se l'aggregazione è abilitata"""
|
||||
return self._aggregation_enabled
|
||||
|
||||
# Metodi proxy per completare l'interfaccia BaseWrapper
|
||||
@property
|
||||
def wrappers(self):
|
||||
"""Accesso al wrapper handler per compatibilità"""
|
||||
return self._market_apis.wrappers
|
||||
|
||||
def get_aggregated_product_with_debug(self, asset_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Metodo speciale per debugging: restituisce dati aggregati con metadati.
|
||||
Usato solo per testing e monitoraggio.
|
||||
"""
|
||||
try:
|
||||
raw_results = self.wrappers.try_call_all(
|
||||
lambda wrapper: wrapper.get_product(asset_id)
|
||||
)
|
||||
|
||||
products = []
|
||||
for wrapper_class, result in raw_results.items():
|
||||
if isinstance(result, ProductInfo):
|
||||
products.append(result)
|
||||
elif isinstance(result, dict):
|
||||
products.append(ProductInfo(**result))
|
||||
|
||||
if not products:
|
||||
raise Exception("Nessun dato disponibile")
|
||||
|
||||
aggregated = AggregatedProductInfo.from_multiple_sources(products)
|
||||
|
||||
return {
|
||||
"product": aggregated.dict(exclude={"_metadata", "_source_data"}),
|
||||
"debug": aggregated.get_debug_info()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"debug": {"error": str(e)}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import time
|
||||
import traceback
|
||||
from typing import TypeVar, Callable, Generic, Iterable, Type
|
||||
@@ -45,17 +46,24 @@ class WrapperHandler(Generic[W]):
|
||||
Raises:
|
||||
Exception: If all wrappers fail after retries.
|
||||
"""
|
||||
log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}")
|
||||
|
||||
iterations = 0
|
||||
while iterations < len(self.wrappers):
|
||||
wrapper = self.wrappers[self.index]
|
||||
wrapper_name = wrapper.__class__.__name__
|
||||
|
||||
try:
|
||||
wrapper = self.wrappers[self.index]
|
||||
log_info(f"Trying wrapper: {wrapper} - function {func}")
|
||||
log_info(f"try_call {wrapper_name}")
|
||||
result = func(wrapper)
|
||||
log_info(f"{wrapper_name} succeeded")
|
||||
self.retry_count = 0
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.retry_count += 1
|
||||
log_warning(f"{wrapper} failed {self.retry_count}/{self.retry_per_wrapper}: {WrapperHandler.__concise_error(e)}")
|
||||
error = WrapperHandler.__concise_error(e)
|
||||
log_warning(f"{wrapper_name} failed {self.retry_count}/{self.retry_per_wrapper}: {error}")
|
||||
|
||||
if self.retry_count >= self.retry_per_wrapper:
|
||||
self.index = (self.index + 1) % len(self.wrappers)
|
||||
@@ -64,7 +72,7 @@ class WrapperHandler(Generic[W]):
|
||||
else:
|
||||
time.sleep(self.retry_delay)
|
||||
|
||||
raise Exception(f"All wrappers failed after retries")
|
||||
raise Exception(f"All wrappers failed, latest error: {error}")
|
||||
|
||||
def try_call_all(self, func: Callable[[W], T]) -> dict[str, T]:
|
||||
"""
|
||||
@@ -78,28 +86,33 @@ class WrapperHandler(Generic[W]):
|
||||
Raises:
|
||||
Exception: If all wrappers fail.
|
||||
"""
|
||||
log_info(f"{inspect.getsource(func).strip()} {inspect.getclosurevars(func).nonlocals}")
|
||||
|
||||
results = {}
|
||||
log_info(f"All wrappers: {[wrapper.__class__ for wrapper in self.wrappers]} - function {func}")
|
||||
for wrapper in self.wrappers:
|
||||
wrapper_name = wrapper.__class__.__name__
|
||||
try:
|
||||
result = func(wrapper)
|
||||
log_info(f"{wrapper_name} succeeded")
|
||||
results[wrapper.__class__] = result
|
||||
except Exception as e:
|
||||
log_warning(f"{wrapper} failed: {WrapperHandler.__concise_error(e)}")
|
||||
error = WrapperHandler.__concise_error(e)
|
||||
log_warning(f"{wrapper_name} failed: {error}")
|
||||
if not results:
|
||||
raise Exception("All wrappers failed")
|
||||
raise Exception(f"All wrappers failed, latest error: {error}")
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def __check(wrappers: list[W]) -> bool:
|
||||
return all(w.__class__ is type for w in wrappers)
|
||||
|
||||
@staticmethod
|
||||
def __concise_error(e: Exception) -> str:
|
||||
last_frame = traceback.extract_tb(e.__traceback__)[-1]
|
||||
return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]"
|
||||
|
||||
@staticmethod
|
||||
def build_wrappers(constructors: Iterable[Type[W]], try_per_wrapper: int = 3, retry_delay: int = 2) -> 'WrapperHandler[W]':
|
||||
def build_wrappers(constructors: Iterable[Type[W]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict | None = None) -> 'WrapperHandler[W]':
|
||||
"""
|
||||
Builds a WrapperHandler instance with the given wrapper constructors.
|
||||
It attempts to initialize each wrapper and logs a warning if any cannot be initialized.
|
||||
@@ -108,6 +121,7 @@ class WrapperHandler(Generic[W]):
|
||||
constructors (Iterable[Type[W]]): An iterable of wrapper classes to instantiate. e.g. [WrapperA, WrapperB]
|
||||
try_per_wrapper (int): Number of retries per wrapper before switching to the next.
|
||||
retry_delay (int): Delay in seconds between retries.
|
||||
kwargs (dict | None): Optional dictionary with keyword arguments common to all wrappers.
|
||||
Returns:
|
||||
WrapperHandler[W]: An instance of WrapperHandler with the initialized wrappers.
|
||||
Raises:
|
||||
@@ -118,7 +132,7 @@ class WrapperHandler(Generic[W]):
|
||||
result = []
|
||||
for wrapper_class in constructors:
|
||||
try:
|
||||
wrapper = wrapper_class()
|
||||
wrapper = wrapper_class(**(kwargs or {}))
|
||||
result.append(wrapper)
|
||||
except Exception as e:
|
||||
log_warning(f"{wrapper_class} cannot be initialized: {e}")
|
||||
|
||||
Reference in New Issue
Block a user