12 fix docs (#13)
* fix dependencies uv.lock * refactor test markers for clarity * refactor: clean up imports and remove unused files * refactor: remove unused agent files and clean up market API instructions * refactor: enhance wrapper initialization with keyword arguments and clean up tests * refactor: remove PublicBinanceAgent * refactor: aggregator - simplified MarketDataAggregator and related models to functions * refactor: update README and .env.example to reflect the latest changes to the project * refactor: simplify product info and price creation in YFinanceWrapper * refactor: remove get_all_products method from market API wrappers and update documentation * fix: environment variable assertions * refactor: remove status attribute from ProductInfo and update related methods to use timestamp_ms * feat: implement aggregate_history_prices function to calculate hourly price averages * refactor: update docker-compose and app.py for improved environment variable handling and compatibility * feat: add detailed market instructions and improve error handling in price aggregation methods * feat: add aggregated news retrieval methods for top headlines and latest news * refactor: improve error messages in WrapperHandler for better clarity * fix: correct quote currency extraction in create_product_info and remove debug prints from tests
This commit was merged in pull request #13.
This commit is contained in:
committed by
GitHub
parent
a8755913d8
commit
d2fbc0ceea
@@ -45,8 +45,9 @@ class TestBinance:
|
||||
assert isinstance(history, list)
|
||||
assert len(history) == 5
|
||||
for entry in history:
|
||||
assert hasattr(entry, 'time')
|
||||
assert hasattr(entry, 'timestamp_ms')
|
||||
assert hasattr(entry, 'close')
|
||||
assert hasattr(entry, 'high')
|
||||
assert entry.close > 0
|
||||
assert entry.high > 0
|
||||
assert entry.timestamp_ms > 0
|
||||
|
||||
@@ -47,8 +47,9 @@ class TestCoinBase:
|
||||
assert isinstance(history, list)
|
||||
assert len(history) == 5
|
||||
for entry in history:
|
||||
assert hasattr(entry, 'time')
|
||||
assert hasattr(entry, 'timestamp_ms')
|
||||
assert hasattr(entry, 'close')
|
||||
assert hasattr(entry, 'high')
|
||||
assert entry.close > 0
|
||||
assert entry.high > 0
|
||||
assert entry.timestamp_ms > 0
|
||||
|
||||
@@ -49,8 +49,9 @@ class TestCryptoCompare:
|
||||
assert isinstance(history, list)
|
||||
assert len(history) == 5
|
||||
for entry in history:
|
||||
assert hasattr(entry, 'time')
|
||||
assert hasattr(entry, 'timestamp_ms')
|
||||
assert hasattr(entry, 'close')
|
||||
assert hasattr(entry, 'high')
|
||||
assert entry.close > 0
|
||||
assert entry.high > 0
|
||||
assert entry.timestamp_ms > 0
|
||||
|
||||
@@ -5,7 +5,7 @@ from app.news import NewsApiWrapper
|
||||
|
||||
@pytest.mark.news
|
||||
@pytest.mark.api
|
||||
@pytest.mark.skipif(not os.getenv("NEWS_API_KEY"), reason="NEWS_API_KEY not set")
|
||||
@pytest.mark.skipif(not os.getenv("NEWS_API_KEY"), reason="NEWS_API_KEY not set in environment variables")
|
||||
class TestNewsAPI:
|
||||
|
||||
def test_news_api_initialization(self):
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import pytest
|
||||
from praw import Reddit
|
||||
from app.social.reddit import MAX_COMMENTS, RedditWrapper
|
||||
|
||||
@pytest.mark.social
|
||||
@pytest.mark.api
|
||||
@pytest.mark.skipif(not(os.getenv("REDDIT_API_CLIENT_ID")) or not os.getenv("REDDIT_API_CLIENT_SECRET"), reason="REDDIT_CLIENT_ID and REDDIT_API_CLIENT_SECRET not set in environment variables")
|
||||
class TestRedditWrapper:
|
||||
def test_initialization(self):
|
||||
wrapper = RedditWrapper()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import pytest
|
||||
from app.markets import YFinanceWrapper
|
||||
|
||||
@@ -14,17 +13,6 @@ class TestYFinance:
|
||||
assert hasattr(market, 'tool')
|
||||
assert market.tool is not None
|
||||
|
||||
def test_yfinance_get_product(self):
|
||||
market = YFinanceWrapper()
|
||||
product = market.get_product("AAPL")
|
||||
assert product is not None
|
||||
assert hasattr(product, 'symbol')
|
||||
assert product.symbol == "AAPL"
|
||||
assert hasattr(product, 'price')
|
||||
assert product.price > 0
|
||||
assert hasattr(product, 'status')
|
||||
assert product.status == "trading"
|
||||
|
||||
def test_yfinance_get_crypto_product(self):
|
||||
market = YFinanceWrapper()
|
||||
product = market.get_product("BTC")
|
||||
@@ -37,57 +25,32 @@ class TestYFinance:
|
||||
|
||||
def test_yfinance_get_products(self):
|
||||
market = YFinanceWrapper()
|
||||
products = market.get_products(["AAPL", "GOOGL"])
|
||||
products = market.get_products(["BTC", "ETH"])
|
||||
assert products is not None
|
||||
assert isinstance(products, list)
|
||||
assert len(products) == 2
|
||||
symbols = [p.symbol for p in products]
|
||||
assert "AAPL" in symbols
|
||||
assert "GOOGL" in symbols
|
||||
assert "BTC" in symbols
|
||||
assert "ETH" in symbols
|
||||
for product in products:
|
||||
assert hasattr(product, 'price')
|
||||
assert product.price > 0
|
||||
|
||||
def test_yfinance_get_all_products(self):
|
||||
market = YFinanceWrapper()
|
||||
products = market.get_all_products()
|
||||
assert products is not None
|
||||
assert isinstance(products, list)
|
||||
assert len(products) > 0
|
||||
# Dovrebbe contenere asset popolari
|
||||
symbols = [p.symbol for p in products]
|
||||
assert "AAPL" in symbols # Apple dovrebbe essere nella lista
|
||||
for product in products:
|
||||
assert hasattr(product, 'symbol')
|
||||
assert hasattr(product, 'price')
|
||||
|
||||
def test_yfinance_invalid_product(self):
|
||||
market = YFinanceWrapper()
|
||||
# Per YFinance, un prodotto invalido dovrebbe restituire un prodotto offline
|
||||
product = market.get_product("INVALIDSYMBOL123")
|
||||
assert product is not None
|
||||
assert product.status == "offline"
|
||||
with pytest.raises(Exception):
|
||||
_ = market.get_product("INVALIDSYMBOL123")
|
||||
|
||||
def test_yfinance_history(self):
|
||||
def test_yfinance_crypto_history(self):
|
||||
market = YFinanceWrapper()
|
||||
history = market.get_historical_prices("AAPL", limit=5)
|
||||
history = market.get_historical_prices("BTC", limit=5)
|
||||
assert history is not None
|
||||
assert isinstance(history, list)
|
||||
assert len(history) == 5
|
||||
for entry in history:
|
||||
assert hasattr(entry, 'time')
|
||||
assert hasattr(entry, 'timestamp_ms')
|
||||
assert hasattr(entry, 'close')
|
||||
assert hasattr(entry, 'high')
|
||||
assert entry.close > 0
|
||||
assert entry.high > 0
|
||||
|
||||
def test_yfinance_crypto_history(self):
|
||||
market = YFinanceWrapper()
|
||||
history = market.get_historical_prices("BTC", limit=3)
|
||||
assert history is not None
|
||||
assert isinstance(history, list)
|
||||
assert len(history) == 3
|
||||
for entry in history:
|
||||
assert hasattr(entry, 'time')
|
||||
assert hasattr(entry, 'close')
|
||||
assert entry.close > 0
|
||||
assert entry.timestamp_ms > 0
|
||||
|
||||
@@ -14,17 +14,20 @@ def pytest_configure(config:pytest.Config):
|
||||
|
||||
markers = [
|
||||
("slow", "marks tests as slow (deselect with '-m \"not slow\"')"),
|
||||
("limited", "marks tests that have limited execution due to API constraints"),
|
||||
|
||||
("api", "marks tests that require API access"),
|
||||
("market", "marks tests that use market data"),
|
||||
("news", "marks tests that use news"),
|
||||
("social", "marks tests that use social media"),
|
||||
("wrapper", "marks tests for wrapper handler"),
|
||||
|
||||
("tools", "marks tests for tools"),
|
||||
("aggregator", "marks tests for market data aggregator"),
|
||||
|
||||
("gemini", "marks tests that use Gemini model"),
|
||||
("ollama_gpt", "marks tests that use Ollama GPT model"),
|
||||
("ollama_qwen", "marks tests that use Ollama Qwen model"),
|
||||
("news", "marks tests that use news"),
|
||||
("social", "marks tests that use social media"),
|
||||
("limited", "marks tests that have limited execution due to API constraints"),
|
||||
("wrapper", "marks tests for wrapper handler"),
|
||||
("tools", "marks tests for tools"),
|
||||
("aggregator", "marks tests for market data aggregator"),
|
||||
]
|
||||
for marker in markers:
|
||||
line = f"{marker[0]}: {marker[1]}"
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
import pytest
|
||||
from app.agents.market_agent import MarketToolkit
|
||||
from app.markets import MarketAPIsTool
|
||||
|
||||
@pytest.mark.limited # usa molte api calls e non voglio esaurire le chiavi api
|
||||
|
||||
@pytest.mark.tools
|
||||
@pytest.mark.market
|
||||
@pytest.mark.api
|
||||
class TestMarketAPIsTool:
|
||||
def test_wrapper_initialization(self):
|
||||
@@ -12,7 +11,6 @@ class TestMarketAPIsTool:
|
||||
assert market_wrapper is not None
|
||||
assert hasattr(market_wrapper, 'get_product')
|
||||
assert hasattr(market_wrapper, 'get_products')
|
||||
assert hasattr(market_wrapper, 'get_all_products')
|
||||
assert hasattr(market_wrapper, 'get_historical_prices')
|
||||
|
||||
def test_wrapper_capabilities(self):
|
||||
@@ -34,27 +32,6 @@ class TestMarketAPIsTool:
|
||||
assert hasattr(btc_product, 'price')
|
||||
assert btc_product.price > 0
|
||||
|
||||
def test_market_toolkit_integration(self):
|
||||
try:
|
||||
toolkit = MarketToolkit()
|
||||
assert toolkit is not None
|
||||
assert hasattr(toolkit, 'market_agent')
|
||||
assert toolkit.market_api is not None
|
||||
|
||||
tools = toolkit.tools
|
||||
assert len(tools) > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"MarketToolkit test failed: {e}")
|
||||
# Non fail completamente - il toolkit potrebbe avere dipendenze specifiche
|
||||
|
||||
def test_provider_selection_mechanism(self):
|
||||
potential_providers = 0
|
||||
if os.getenv('CDP_API_KEY_NAME') and os.getenv('CDP_API_PRIVATE_KEY'):
|
||||
potential_providers += 1
|
||||
if os.getenv('CRYPTOCOMPARE_API_KEY'):
|
||||
potential_providers += 1
|
||||
|
||||
def test_error_handling(self):
|
||||
try:
|
||||
market_wrapper = MarketAPIsTool("USD")
|
||||
@@ -62,9 +39,3 @@ class TestMarketAPIsTool:
|
||||
assert fake_product is None or fake_product.price == 0
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def test_wrapper_currency_support(self):
|
||||
market_wrapper = MarketAPIsTool("USD")
|
||||
assert hasattr(market_wrapper, 'currency')
|
||||
assert isinstance(market_wrapper.currency, str)
|
||||
assert len(market_wrapper.currency) >= 3 # USD, EUR, etc.
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytest
|
||||
from app.news import NewsAPIsTool
|
||||
|
||||
|
||||
@pytest.mark.limited
|
||||
@pytest.mark.tools
|
||||
@pytest.mark.news
|
||||
@pytest.mark.api
|
||||
@@ -34,10 +33,8 @@ class TestNewsAPITool:
|
||||
result = tool.wrapper_handler.try_call_all(lambda w: w.get_top_headlines(limit=2))
|
||||
assert isinstance(result, dict)
|
||||
assert len(result.keys()) > 0
|
||||
print("Results from providers:", result.keys())
|
||||
for provider, articles in result.items():
|
||||
for article in articles:
|
||||
print(provider, article.title)
|
||||
assert article.title is not None
|
||||
assert article.source is not None
|
||||
|
||||
@@ -46,9 +43,7 @@ class TestNewsAPITool:
|
||||
result = tool.wrapper_handler.try_call_all(lambda w: w.get_latest_news(query="crypto", limit=2))
|
||||
assert isinstance(result, dict)
|
||||
assert len(result.keys()) > 0
|
||||
print("Results from providers:", result.keys())
|
||||
for provider, articles in result.items():
|
||||
for article in articles:
|
||||
print(provider, article.title)
|
||||
assert article.title is not None
|
||||
assert article.source is not None
|
||||
|
||||
@@ -24,9 +24,7 @@ class TestSocialAPIsTool:
|
||||
result = tool.wrapper_handler.try_call_all(lambda w: w.get_top_crypto_posts(limit=2))
|
||||
assert isinstance(result, dict)
|
||||
assert len(result.keys()) > 0
|
||||
print("Results from providers:", result.keys())
|
||||
for provider, posts in result.items():
|
||||
for post in posts:
|
||||
print(provider, post.title)
|
||||
assert post.title is not None
|
||||
assert post.time is not None
|
||||
|
||||
120
tests/utils/test_market_aggregator.py
Normal file
120
tests/utils/test_market_aggregator.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import pytest
|
||||
from app.markets.base import ProductInfo, Price
|
||||
from app.utils.market_aggregation import aggregate_history_prices, aggregate_product_info
|
||||
|
||||
|
||||
@pytest.mark.aggregator
|
||||
@pytest.mark.market
|
||||
class TestMarketDataAggregator:
|
||||
|
||||
def __product(self, symbol: str, price: float, volume: float, currency: str) -> ProductInfo:
|
||||
prod = ProductInfo()
|
||||
prod.id=f"{symbol}-{currency}"
|
||||
prod.symbol=symbol
|
||||
prod.price=price
|
||||
prod.volume_24h=volume
|
||||
prod.quote_currency=currency
|
||||
return prod
|
||||
|
||||
def __price(self, timestamp_ms: int, high: float, low: float, open: float, close: float, volume: float) -> Price:
|
||||
price = Price()
|
||||
price.timestamp_ms = timestamp_ms
|
||||
price.high = high
|
||||
price.low = low
|
||||
price.open = open
|
||||
price.close = close
|
||||
price.volume = volume
|
||||
return price
|
||||
|
||||
def test_aggregate_product_info(self):
|
||||
products: dict[str, list[ProductInfo]] = {
|
||||
"Provider1": [self.__product("BTC", 50000.0, 1000.0, "USD")],
|
||||
"Provider2": [self.__product("BTC", 50100.0, 1100.0, "USD")],
|
||||
"Provider3": [self.__product("BTC", 49900.0, 900.0, "USD")],
|
||||
}
|
||||
|
||||
aggregated = aggregate_product_info(products)
|
||||
assert len(aggregated) == 1
|
||||
|
||||
info = aggregated[0]
|
||||
assert info is not None
|
||||
assert info.symbol == "BTC"
|
||||
|
||||
avg_weighted_price = (50000.0 * 1000.0 + 50100.0 * 1100.0 + 49900.0 * 900.0) / (1000.0 + 1100.0 + 900.0)
|
||||
assert info.price == pytest.approx(avg_weighted_price, rel=1e-3)
|
||||
assert info.volume_24h == pytest.approx(1000.0, rel=1e-3)
|
||||
assert info.quote_currency == "USD"
|
||||
|
||||
def test_aggregate_product_info_multiple_symbols(self):
|
||||
products = {
|
||||
"Provider1": [
|
||||
self.__product("BTC", 50000.0, 1000.0, "USD"),
|
||||
self.__product("ETH", 4000.0, 2000.0, "USD"),
|
||||
],
|
||||
"Provider2": [
|
||||
self.__product("BTC", 50100.0, 1100.0, "USD"),
|
||||
self.__product("ETH", 4050.0, 2100.0, "USD"),
|
||||
],
|
||||
}
|
||||
|
||||
aggregated = aggregate_product_info(products)
|
||||
assert len(aggregated) == 2
|
||||
|
||||
btc_info = next((p for p in aggregated if p.symbol == "BTC"), None)
|
||||
eth_info = next((p for p in aggregated if p.symbol == "ETH"), None)
|
||||
|
||||
assert btc_info is not None
|
||||
avg_weighted_price_btc = (50000.0 * 1000.0 + 50100.0 * 1100.0) / (1000.0 + 1100.0)
|
||||
assert btc_info.price == pytest.approx(avg_weighted_price_btc, rel=1e-3)
|
||||
assert btc_info.volume_24h == pytest.approx(1050.0, rel=1e-3)
|
||||
assert btc_info.quote_currency == "USD"
|
||||
|
||||
assert eth_info is not None
|
||||
avg_weighted_price_eth = (4000.0 * 2000.0 + 4050.0 * 2100.0) / (2000.0 + 2100.0)
|
||||
assert eth_info.price == pytest.approx(avg_weighted_price_eth, rel=1e-3)
|
||||
assert eth_info.volume_24h == pytest.approx(2050.0, rel=1e-3)
|
||||
assert eth_info.quote_currency == "USD"
|
||||
|
||||
def test_aggregate_product_info_with_no_data(self):
|
||||
products = {
|
||||
"Provider1": [],
|
||||
"Provider2": [],
|
||||
}
|
||||
aggregated = aggregate_product_info(products)
|
||||
assert len(aggregated) == 0
|
||||
|
||||
def test_aggregate_product_info_with_partial_data(self):
|
||||
products = {
|
||||
"Provider1": [self.__product("BTC", 50000.0, 1000.0, "USD")],
|
||||
"Provider2": [],
|
||||
}
|
||||
aggregated = aggregate_product_info(products)
|
||||
assert len(aggregated) == 1
|
||||
info = aggregated[0]
|
||||
assert info.symbol == "BTC"
|
||||
assert info.price == pytest.approx(50000.0, rel=1e-3)
|
||||
assert info.volume_24h == pytest.approx(1000.0, rel=1e-3)
|
||||
assert info.quote_currency == "USD"
|
||||
|
||||
def test_aggregate_history_prices(self):
|
||||
"""Test aggregazione di prezzi storici usando aggregate_history_prices"""
|
||||
|
||||
prices = {
|
||||
"Provider1": [
|
||||
self.__price(1685577600000, 50000.0, 49500.0, 49600.0, 49900.0, 150.0),
|
||||
self.__price(1685581200000, 50200.0, 49800.0, 50000.0, 50100.0, 200.0),
|
||||
],
|
||||
"Provider2": [
|
||||
self.__price(1685577600000, 50100.0, 49600.0, 49700.0, 50000.0, 180.0),
|
||||
self.__price(1685581200000, 50300.0, 49900.0, 50100.0, 50200.0, 220.0),
|
||||
],
|
||||
}
|
||||
|
||||
aggregated = aggregate_history_prices(prices)
|
||||
assert len(aggregated) == 2
|
||||
assert aggregated[0].timestamp_ms == 1685577600000
|
||||
assert aggregated[0].high == pytest.approx(50050.0, rel=1e-3)
|
||||
assert aggregated[0].low == pytest.approx(49550.0, rel=1e-3)
|
||||
assert aggregated[1].timestamp_ms == 1685581200000
|
||||
assert aggregated[1].high == pytest.approx(50250.0, rel=1e-3)
|
||||
assert aggregated[1].low == pytest.approx(49850.0, rel=1e-3)
|
||||
@@ -1,88 +0,0 @@
|
||||
import pytest
|
||||
from app.utils.market_data_aggregator import MarketDataAggregator
|
||||
from app.utils.aggregated_models import AggregatedProductInfo
|
||||
from app.markets.base import ProductInfo, Price
|
||||
|
||||
@pytest.mark.aggregator
|
||||
@pytest.mark.limited
|
||||
@pytest.mark.market
|
||||
@pytest.mark.api
|
||||
class TestMarketDataAggregator:
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test che il MarketDataAggregator si inizializzi correttamente"""
|
||||
aggregator = MarketDataAggregator()
|
||||
assert aggregator is not None
|
||||
assert aggregator.is_aggregation_enabled() == True
|
||||
|
||||
def test_aggregation_toggle(self):
|
||||
"""Test del toggle dell'aggregazione"""
|
||||
aggregator = MarketDataAggregator()
|
||||
|
||||
# Disabilita aggregazione
|
||||
aggregator.enable_aggregation(False)
|
||||
assert aggregator.is_aggregation_enabled() == False
|
||||
|
||||
# Riabilita aggregazione
|
||||
aggregator.enable_aggregation(True)
|
||||
assert aggregator.is_aggregation_enabled() == True
|
||||
|
||||
def test_aggregated_product_info_creation(self):
|
||||
"""Test creazione AggregatedProductInfo da fonti multiple"""
|
||||
|
||||
# Crea dati di esempio
|
||||
product1 = ProductInfo(
|
||||
id="BTC-USD",
|
||||
symbol="BTC-USD",
|
||||
price=50000.0,
|
||||
volume_24h=1000.0,
|
||||
status="active",
|
||||
quote_currency="USD"
|
||||
)
|
||||
|
||||
product2 = ProductInfo(
|
||||
id="BTC-USD",
|
||||
symbol="BTC-USD",
|
||||
price=50100.0,
|
||||
volume_24h=1100.0,
|
||||
status="active",
|
||||
quote_currency="USD"
|
||||
)
|
||||
|
||||
# Aggrega i prodotti
|
||||
aggregated = AggregatedProductInfo.from_multiple_sources([product1, product2])
|
||||
|
||||
assert aggregated.symbol == "BTC-USD"
|
||||
assert aggregated.price == pytest.approx(50050.0, rel=1e-3) # media tra 50000 e 50100
|
||||
assert aggregated.volume_24h == 50052.38095 # somma dei volumi
|
||||
assert aggregated.status == "active" # majority vote
|
||||
assert aggregated.id == "BTC-USD_AGG" # mapping_id con suffisso aggregazione
|
||||
|
||||
def test_confidence_calculation(self):
|
||||
"""Test del calcolo della confidence"""
|
||||
|
||||
product1 = ProductInfo(
|
||||
id="BTC-USD",
|
||||
symbol="BTC-USD",
|
||||
price=50000.0,
|
||||
volume_24h=1000.0,
|
||||
status="active",
|
||||
quote_currency="USD"
|
||||
)
|
||||
|
||||
product2 = ProductInfo(
|
||||
id="BTC-USD",
|
||||
symbol="BTC-USD",
|
||||
price=50100.0,
|
||||
volume_24h=1100.0,
|
||||
status="active",
|
||||
quote_currency="USD"
|
||||
)
|
||||
|
||||
aggregated = AggregatedProductInfo.from_multiple_sources([product1, product2])
|
||||
|
||||
# Verifica che ci siano metadati
|
||||
assert aggregated._metadata is not None
|
||||
assert len(aggregated._metadata.sources_used) > 0
|
||||
assert aggregated._metadata.aggregation_timestamp != ""
|
||||
# La confidence può essere 0.0 se ci sono fonti "unknown"
|
||||
@@ -54,7 +54,7 @@ class TestWrapperHandler:
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
handler.try_call(lambda w: w.do_something())
|
||||
assert "All wrappers failed after retries" in str(exc_info.value)
|
||||
assert "All wrappers failed" in str(exc_info.value)
|
||||
|
||||
def test_success_on_first_try(self):
|
||||
wrappers = [MockWrapper, FailingWrapper]
|
||||
@@ -121,7 +121,6 @@ class TestWrapperHandler:
|
||||
handler_all_fail.try_call_all(lambda w: w.do_something())
|
||||
assert "All wrappers failed" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_wrappers_with_parameters(self):
|
||||
wrappers = [FailingWrapperWithParameters, MockWrapperWithParameters]
|
||||
handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0)
|
||||
@@ -130,3 +129,24 @@ class TestWrapperHandler:
|
||||
assert result == "Success test and 42"
|
||||
assert handler.index == 1 # Should have switched to the second wrapper
|
||||
assert handler.retry_count == 0
|
||||
|
||||
def test_wrappers_with_parameters_all_fail(self):
|
||||
wrappers = [FailingWrapperWithParameters, FailingWrapperWithParameters]
|
||||
handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
handler.try_call(lambda w: w.do_something("test", 42))
|
||||
assert "All wrappers failed" in str(exc_info.value)
|
||||
|
||||
def test_try_call_all_with_parameters(self):
|
||||
wrappers = [FailingWrapperWithParameters, MockWrapperWithParameters]
|
||||
handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0)
|
||||
results = handler.try_call_all(lambda w: w.do_something("param", 99))
|
||||
assert results == {MockWrapperWithParameters: "Success param and 99"}
|
||||
|
||||
def test_try_call_all_with_parameters_all_fail(self):
|
||||
wrappers = [FailingWrapperWithParameters, FailingWrapperWithParameters]
|
||||
handler: WrapperHandler[MockWrapperWithParameters] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=1, retry_delay=0)
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
handler.try_call_all(lambda w: w.do_something("param", 99))
|
||||
assert "All wrappers failed" in str(exc_info.value)
|
||||
|
||||
Reference in New Issue
Block a user