Refactor API wrapper initialization to streamline configuration handling

This commit is contained in:
2025-10-22 17:06:16 +02:00
parent 93d005c3e5
commit 803ef22fea
6 changed files with 28 additions and 85 deletions

View File

@@ -12,30 +12,16 @@ class MarketAPIsTool(MarketWrapper, Toolkit):
Providers can be configured in configs.yaml under api.market_providers. Providers can be configured in configs.yaml under api.market_providers.
""" """
# Mapping of wrapper names to wrapper classes
_WRAPPER_MAP = {
'BinanceWrapper': BinanceWrapper,
'YFinanceWrapper': YFinanceWrapper,
'CoinBaseWrapper': CoinBaseWrapper,
'CryptoCompareWrapper': CryptoCompareWrapper,
}
def __init__(self): def __init__(self):
""" """
Initialize the MarketAPIsTool with market API wrappers configured in configs.yaml. Initialize the MarketAPIsTool with market API wrappers configured in configs.yaml.
The order of wrappers is determined by the api.market_providers list in the configuration. The order of wrappers is determined by the api.market_providers list in the configuration.
""" """
config = AppConfig() config = AppConfig()
# Get wrapper classes based on configuration using the helper function
wrappers = WrapperHandler.filter_wrappers_by_config(
wrapper_map=self._WRAPPER_MAP,
provider_names=config.api.market_providers,
fallback_wrappers=[BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper]
)
self.handler = WrapperHandler.build_wrappers( self.handler = WrapperHandler.build_wrappers(
wrappers, constructors=[BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper],
filters=config.api.market_providers,
try_per_wrapper=config.api.retry_attempts, try_per_wrapper=config.api.retry_attempts,
retry_delay=config.api.retry_delay_seconds retry_delay=config.api.retry_delay_seconds
) )

View File

@@ -15,30 +15,16 @@ class NewsAPIsTool(NewsWrapper, Toolkit):
If no wrapper succeeds, an exception is raised. If no wrapper succeeds, an exception is raised.
""" """
# Mapping of wrapper names to wrapper classes
_WRAPPER_MAP = {
'GoogleNewsWrapper': GoogleNewsWrapper,
'DuckDuckGoWrapper': DuckDuckGoWrapper,
'NewsApiWrapper': NewsApiWrapper,
'CryptoPanicWrapper': CryptoPanicWrapper,
}
def __init__(self): def __init__(self):
""" """
Initialize the NewsAPIsTool with news API wrappers configured in configs.yaml. Initialize the NewsAPIsTool with news API wrappers configured in configs.yaml.
The order of wrappers is determined by the api.news_providers list in the configuration. The order of wrappers is determined by the api.news_providers list in the configuration.
""" """
config = AppConfig() config = AppConfig()
# Get wrapper classes based on configuration using the helper function
wrappers = WrapperHandler.filter_wrappers_by_config(
wrapper_map=self._WRAPPER_MAP,
provider_names=config.api.news_providers,
fallback_wrappers=[GoogleNewsWrapper, DuckDuckGoWrapper, NewsApiWrapper, CryptoPanicWrapper]
)
self.handler = WrapperHandler.build_wrappers( self.handler = WrapperHandler.build_wrappers(
wrappers, constructors=[NewsApiWrapper, GoogleNewsWrapper, CryptoPanicWrapper, DuckDuckGoWrapper],
filters=config.api.news_providers,
try_per_wrapper=config.api.retry_attempts, try_per_wrapper=config.api.retry_attempts,
retry_delay=config.api.retry_delay_seconds retry_delay=config.api.retry_delay_seconds
) )

View File

@@ -16,29 +16,16 @@ class SocialAPIsTool(SocialWrapper, Toolkit):
If no wrapper succeeds, an exception is raised. If no wrapper succeeds, an exception is raised.
""" """
# Mapping of wrapper names to wrapper classes
_WRAPPER_MAP = {
'RedditWrapper': RedditWrapper,
'XWrapper': XWrapper,
'ChanWrapper': ChanWrapper,
}
def __init__(self): def __init__(self):
""" """
Initialize the SocialAPIsTool with social media API wrappers configured in configs.yaml. Initialize the SocialAPIsTool with social media API wrappers configured in configs.yaml.
The order of wrappers is determined by the api.social_providers list in the configuration. The order of wrappers is determined by the api.social_providers list in the configuration.
""" """
config = AppConfig() config = AppConfig()
# Get wrapper classes based on configuration using the helper function
wrappers = WrapperHandler.filter_wrappers_by_config(
wrapper_map=self._WRAPPER_MAP,
provider_names=config.api.social_providers,
fallback_wrappers=[RedditWrapper, XWrapper, ChanWrapper]
)
self.handler = WrapperHandler.build_wrappers( self.handler = WrapperHandler.build_wrappers(
wrappers, constructors=[RedditWrapper, XWrapper, ChanWrapper],
filters=config.api.social_providers,
try_per_wrapper=config.api.retry_attempts, try_per_wrapper=config.api.retry_attempts,
retry_delay=config.api.retry_delay_seconds retry_delay=config.api.retry_delay_seconds
) )

View File

@@ -131,41 +131,19 @@ class WrapperHandler(Generic[WrapperType]):
return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]" return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]"
@staticmethod @staticmethod
def filter_wrappers_by_config( def build_wrappers(
wrapper_map: dict[str, type[WrapperClassType]], constructors: list[type[WrapperClassType]],
provider_names: list[str], filters: list[str] | None = None,
fallback_wrappers: list[type[WrapperClassType]] | None = None try_per_wrapper: int = 3,
) -> list[type[WrapperClassType]]: retry_delay: int = 2,
""" kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[WrapperClassType]':
Filters wrapper classes based on a list of provider names from configuration.
Args:
wrapper_map (dict[str, type[W]]): Dictionary mapping provider names to wrapper classes.
provider_names (list[str]): List of provider names from configuration.
fallback_wrappers (list[type[W]] | None): Optional fallback list if no providers configured.
Returns:
list[type[W]]: List of wrapper classes in the order specified by provider_names.
"""
wrappers: list[type[WrapperClassType]] = []
for provider_name in provider_names:
if provider_name in wrapper_map:
wrappers.append(wrapper_map[provider_name])
# Fallback to all wrappers if none configured
if not wrappers and fallback_wrappers:
wrappers = fallback_wrappers
return wrappers
@staticmethod
def build_wrappers(constructors: list[type[WrapperClassType]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[WrapperClassType]':
""" """
Builds a WrapperHandler instance with the given wrapper constructors. Builds a WrapperHandler instance with the given wrapper constructors.
It attempts to initialize each wrapper and logs a warning if any cannot be initialized. It attempts to initialize each wrapper and logs a warning if any cannot be initialized.
Only successfully initialized wrappers are included in the handler. Only successfully initialized wrappers are included in the handler.
Args: Args:
constructors (list[type[W]]): An iterable of wrapper classes to instantiate. e.g. [WrapperA, WrapperB] constructors (list[type[W]]): An iterable of wrapper classes to instantiate. e.g. [WrapperA, WrapperB]
filters (list[str] | None): Optional list of provider names to filter the constructors.
try_per_wrapper (int): Number of retries per wrapper before switching to the next. try_per_wrapper (int): Number of retries per wrapper before switching to the next.
retry_delay (int): Delay in seconds between retries. retry_delay (int): Delay in seconds between retries.
kwargs (dict | None): Optional dictionary with keyword arguments common to all wrappers. kwargs (dict | None): Optional dictionary with keyword arguments common to all wrappers.
@@ -176,8 +154,14 @@ class WrapperHandler(Generic[WrapperType]):
""" """
assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}" assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}"
# Order of wrappers is now determined by the order in filters
filters = filters or [c.__name__ for c in constructors]
wrappers = [c for name in filters for c in constructors if c.__name__ == name]
result: list[WrapperClassType] = [] result: list[WrapperClassType] = []
for wrapper_class in constructors: for wrapper_class in wrappers:
if filters and wrapper_class.__name__ not in filters:
continue
try: try:
wrapper = wrapper_class(**(kwargs or {})) wrapper = wrapper_class(**(kwargs or {}))
result.append(wrapper) result.append(wrapper)

View File

@@ -7,14 +7,14 @@ from app.api.tools import MarketAPIsTool
@pytest.mark.api @pytest.mark.api
class TestMarketAPIsTool: class TestMarketAPIsTool:
def test_wrapper_initialization(self): def test_wrapper_initialization(self):
market_wrapper = MarketAPIsTool("EUR") market_wrapper = MarketAPIsTool()
assert market_wrapper is not None assert market_wrapper is not None
assert hasattr(market_wrapper, 'get_product') assert hasattr(market_wrapper, 'get_product')
assert hasattr(market_wrapper, 'get_products') assert hasattr(market_wrapper, 'get_products')
assert hasattr(market_wrapper, 'get_historical_prices') assert hasattr(market_wrapper, 'get_historical_prices')
def test_wrapper_capabilities(self): def test_wrapper_capabilities(self):
market_wrapper = MarketAPIsTool("EUR") market_wrapper = MarketAPIsTool()
capabilities: list[str] = [] capabilities: list[str] = []
if hasattr(market_wrapper, 'get_product'): if hasattr(market_wrapper, 'get_product'):
capabilities.append('single_product') capabilities.append('single_product')
@@ -25,7 +25,7 @@ class TestMarketAPIsTool:
assert len(capabilities) > 0 assert len(capabilities) > 0
def test_market_data_retrieval(self): def test_market_data_retrieval(self):
market_wrapper = MarketAPIsTool("EUR") market_wrapper = MarketAPIsTool()
btc_product = market_wrapper.get_product("BTC") btc_product = market_wrapper.get_product("BTC")
assert btc_product is not None assert btc_product is not None
assert hasattr(btc_product, 'symbol') assert hasattr(btc_product, 'symbol')
@@ -34,7 +34,7 @@ class TestMarketAPIsTool:
def test_error_handling(self): def test_error_handling(self):
try: try:
market_wrapper = MarketAPIsTool("EUR") market_wrapper = MarketAPIsTool()
fake_product = market_wrapper.get_product("NONEXISTENT_CRYPTO_SYMBOL_12345") fake_product = market_wrapper.get_product("NONEXISTENT_CRYPTO_SYMBOL_12345")
assert fake_product is None or fake_product.price == 0 assert fake_product is None or fake_product.price == 0
except Exception as _: except Exception as _:

View File

@@ -19,7 +19,7 @@ class TestSocialAPIsTool:
assert post.title is not None assert post.title is not None
assert post.time is not None assert post.time is not None
def test_social_api_tool_get_top__all_results(self): def test_social_api_tool_get_top_all_results(self):
tool = SocialAPIsTool() tool = SocialAPIsTool()
result = tool.handler.try_call_all(lambda w: w.get_top_crypto_posts(limit=2)) result = tool.handler.try_call_all(lambda w: w.get_top_crypto_posts(limit=2))
assert isinstance(result, dict) assert isinstance(result, dict)