diff --git a/src/app/api/tools/market_tool.py b/src/app/api/tools/market_tool.py index 7d7e8cd..e47fc9f 100644 --- a/src/app/api/tools/market_tool.py +++ b/src/app/api/tools/market_tool.py @@ -12,30 +12,16 @@ class MarketAPIsTool(MarketWrapper, Toolkit): Providers can be configured in configs.yaml under api.market_providers. """ - # Mapping of wrapper names to wrapper classes - _WRAPPER_MAP = { - 'BinanceWrapper': BinanceWrapper, - 'YFinanceWrapper': YFinanceWrapper, - 'CoinBaseWrapper': CoinBaseWrapper, - 'CryptoCompareWrapper': CryptoCompareWrapper, - } - def __init__(self): """ Initialize the MarketAPIsTool with market API wrappers configured in configs.yaml. The order of wrappers is determined by the api.market_providers list in the configuration. """ config = AppConfig() - - # Get wrapper classes based on configuration using the helper function - wrappers = WrapperHandler.filter_wrappers_by_config( - wrapper_map=self._WRAPPER_MAP, - provider_names=config.api.market_providers, - fallback_wrappers=[BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper] - ) - + self.handler = WrapperHandler.build_wrappers( - wrappers, + constructors=[BinanceWrapper, YFinanceWrapper, CoinBaseWrapper, CryptoCompareWrapper], + filters=config.api.market_providers, try_per_wrapper=config.api.retry_attempts, retry_delay=config.api.retry_delay_seconds ) diff --git a/src/app/api/tools/news_tool.py b/src/app/api/tools/news_tool.py index f9558ed..eddf48d 100644 --- a/src/app/api/tools/news_tool.py +++ b/src/app/api/tools/news_tool.py @@ -15,30 +15,16 @@ class NewsAPIsTool(NewsWrapper, Toolkit): If no wrapper succeeds, an exception is raised. """ - # Mapping of wrapper names to wrapper classes - _WRAPPER_MAP = { - 'GoogleNewsWrapper': GoogleNewsWrapper, - 'DuckDuckGoWrapper': DuckDuckGoWrapper, - 'NewsApiWrapper': NewsApiWrapper, - 'CryptoPanicWrapper': CryptoPanicWrapper, - } - def __init__(self): """ Initialize the NewsAPIsTool with news API wrappers configured in configs.yaml. The order of wrappers is determined by the api.news_providers list in the configuration. """ config = AppConfig() - - # Get wrapper classes based on configuration using the helper function - wrappers = WrapperHandler.filter_wrappers_by_config( - wrapper_map=self._WRAPPER_MAP, - provider_names=config.api.news_providers, - fallback_wrappers=[GoogleNewsWrapper, DuckDuckGoWrapper, NewsApiWrapper, CryptoPanicWrapper] - ) - + self.handler = WrapperHandler.build_wrappers( - wrappers, + constructors=[NewsApiWrapper, GoogleNewsWrapper, CryptoPanicWrapper, DuckDuckGoWrapper], + filters=config.api.news_providers, try_per_wrapper=config.api.retry_attempts, retry_delay=config.api.retry_delay_seconds ) diff --git a/src/app/api/tools/social_tool.py b/src/app/api/tools/social_tool.py index b44a969..ab346ca 100644 --- a/src/app/api/tools/social_tool.py +++ b/src/app/api/tools/social_tool.py @@ -16,29 +16,16 @@ class SocialAPIsTool(SocialWrapper, Toolkit): If no wrapper succeeds, an exception is raised. """ - # Mapping of wrapper names to wrapper classes - _WRAPPER_MAP = { - 'RedditWrapper': RedditWrapper, - 'XWrapper': XWrapper, - 'ChanWrapper': ChanWrapper, - } - def __init__(self): """ Initialize the SocialAPIsTool with social media API wrappers configured in configs.yaml. The order of wrappers is determined by the api.social_providers list in the configuration. """ config = AppConfig() - - # Get wrapper classes based on configuration using the helper function - wrappers = WrapperHandler.filter_wrappers_by_config( - wrapper_map=self._WRAPPER_MAP, - provider_names=config.api.social_providers, - fallback_wrappers=[RedditWrapper, XWrapper, ChanWrapper] - ) - + self.handler = WrapperHandler.build_wrappers( - wrappers, + constructors=[RedditWrapper, XWrapper, ChanWrapper], + filters=config.api.social_providers, try_per_wrapper=config.api.retry_attempts, retry_delay=config.api.retry_delay_seconds ) diff --git a/src/app/api/wrapper_handler.py b/src/app/api/wrapper_handler.py index 1b38f25..00aafa2 100644 --- a/src/app/api/wrapper_handler.py +++ b/src/app/api/wrapper_handler.py @@ -131,41 +131,19 @@ class WrapperHandler(Generic[WrapperType]): return f"{e} [\"{last_frame.filename}\", line {last_frame.lineno}]" @staticmethod - def filter_wrappers_by_config( - wrapper_map: dict[str, type[WrapperClassType]], - provider_names: list[str], - fallback_wrappers: list[type[WrapperClassType]] | None = None - ) -> list[type[WrapperClassType]]: - """ - Filters wrapper classes based on a list of provider names from configuration. - - Args: - wrapper_map (dict[str, type[W]]): Dictionary mapping provider names to wrapper classes. - provider_names (list[str]): List of provider names from configuration. - fallback_wrappers (list[type[W]] | None): Optional fallback list if no providers configured. - - Returns: - list[type[W]]: List of wrapper classes in the order specified by provider_names. - """ - wrappers: list[type[WrapperClassType]] = [] - for provider_name in provider_names: - if provider_name in wrapper_map: - wrappers.append(wrapper_map[provider_name]) - - # Fallback to all wrappers if none configured - if not wrappers and fallback_wrappers: - wrappers = fallback_wrappers - - return wrappers - - @staticmethod - def build_wrappers(constructors: list[type[WrapperClassType]], try_per_wrapper: int = 3, retry_delay: int = 2, kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[WrapperClassType]': + def build_wrappers( + constructors: list[type[WrapperClassType]], + filters: list[str] | None = None, + try_per_wrapper: int = 3, + retry_delay: int = 2, + kwargs: dict[str, Any] | None = None) -> 'WrapperHandler[WrapperClassType]': """ Builds a WrapperHandler instance with the given wrapper constructors. It attempts to initialize each wrapper and logs a warning if any cannot be initialized. Only successfully initialized wrappers are included in the handler. Args: constructors (list[type[W]]): An iterable of wrapper classes to instantiate. e.g. [WrapperA, WrapperB] + filters (list[str] | None): Optional list of provider names to filter the constructors. try_per_wrapper (int): Number of retries per wrapper before switching to the next. retry_delay (int): Delay in seconds between retries. kwargs (dict | None): Optional dictionary with keyword arguments common to all wrappers. @@ -176,8 +154,14 @@ class WrapperHandler(Generic[WrapperType]): """ assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}" + # Order of wrappers is now determined by the order in filters + filters = filters or [c.__name__ for c in constructors] + wrappers = [c for name in filters for c in constructors if c.__name__ == name] + result: list[WrapperClassType] = [] - for wrapper_class in constructors: + for wrapper_class in wrappers: + if filters and wrapper_class.__name__ not in filters: + continue try: wrapper = wrapper_class(**(kwargs or {})) result.append(wrapper) diff --git a/tests/tools/test_market_tool.py b/tests/tools/test_market_tool.py index ea90bf2..0787a0b 100644 --- a/tests/tools/test_market_tool.py +++ b/tests/tools/test_market_tool.py @@ -7,14 +7,14 @@ from app.api.tools import MarketAPIsTool @pytest.mark.api class TestMarketAPIsTool: def test_wrapper_initialization(self): - market_wrapper = MarketAPIsTool("EUR") + market_wrapper = MarketAPIsTool() assert market_wrapper is not None assert hasattr(market_wrapper, 'get_product') assert hasattr(market_wrapper, 'get_products') assert hasattr(market_wrapper, 'get_historical_prices') def test_wrapper_capabilities(self): - market_wrapper = MarketAPIsTool("EUR") + market_wrapper = MarketAPIsTool() capabilities: list[str] = [] if hasattr(market_wrapper, 'get_product'): capabilities.append('single_product') @@ -25,7 +25,7 @@ class TestMarketAPIsTool: assert len(capabilities) > 0 def test_market_data_retrieval(self): - market_wrapper = MarketAPIsTool("EUR") + market_wrapper = MarketAPIsTool() btc_product = market_wrapper.get_product("BTC") assert btc_product is not None assert hasattr(btc_product, 'symbol') @@ -34,7 +34,7 @@ class TestMarketAPIsTool: def test_error_handling(self): try: - market_wrapper = MarketAPIsTool("EUR") + market_wrapper = MarketAPIsTool() fake_product = market_wrapper.get_product("NONEXISTENT_CRYPTO_SYMBOL_12345") assert fake_product is None or fake_product.price == 0 except Exception as _: diff --git a/tests/tools/test_socials_tool.py b/tests/tools/test_socials_tool.py index c021a90..72c6681 100644 --- a/tests/tools/test_socials_tool.py +++ b/tests/tools/test_socials_tool.py @@ -19,7 +19,7 @@ class TestSocialAPIsTool: assert post.title is not None assert post.time is not None - def test_social_api_tool_get_top__all_results(self): + def test_social_api_tool_get_top_all_results(self): tool = SocialAPIsTool() result = tool.handler.try_call_all(lambda w: w.get_top_crypto_posts(limit=2)) assert isinstance(result, dict)