diff --git a/src/app/social/__init__.py b/src/app/social/__init__.py index de83e63..9ce3708 100644 --- a/src/app/social/__init__.py +++ b/src/app/social/__init__.py @@ -26,7 +26,7 @@ class SocialAPIsTool(SocialWrapper, Toolkit): """ wrappers = [RedditWrapper] - self.wrapper_handler: WrapperHandler[SocialWrapper] = WrapperHandler(wrappers) + self.wrapper_handler: WrapperHandler[SocialWrapper] = WrapperHandler.build_wrappers(wrappers) Toolkit.__init__( self, @@ -38,7 +38,7 @@ class SocialAPIsTool(SocialWrapper, Toolkit): # TODO Pensare se ha senso restituire i post da TUTTI i wrapper o solo dal primo che funziona # la modifica รจ banale, basta usare try_call_all invece di try_call - def get_top_crypto_posts(self, limit:int = 5) -> list[SocialPost]: + def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: return self.wrapper_handler.try_call(lambda w: w.get_top_crypto_posts(limit)) diff --git a/src/app/social/base.py b/src/app/social/base.py index 1b66c1d..dd894f5 100644 --- a/src/app/social/base.py +++ b/src/app/social/base.py @@ -17,6 +17,7 @@ class SocialWrapper: Base class for social media API wrappers. All social media API wrappers should inherit from this class and implement the methods. """ + def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: """ Get top cryptocurrency-related posts, optionally limited by total. diff --git a/src/app/social/reddit.py b/src/app/social/reddit.py index 730f862..8f3867f 100644 --- a/src/app/social/reddit.py +++ b/src/app/social/reddit.py @@ -35,19 +35,19 @@ class RedditWrapper(SocialWrapper): """ def __init__(self): - self.client_id = os.getenv("REDDIT_API_CLIENT_ID") - assert self.client_id is not None, "REDDIT_API_CLIENT_ID environment variable is not set" + client_id = os.getenv("REDDIT_API_CLIENT_ID") + assert client_id is not None, "REDDIT_API_CLIENT_ID environment variable is not set" - self.client_secret = os.getenv("REDDIT_API_CLIENT_SECRET") - assert self.client_secret is not None, "REDDIT_API_CLIENT_SECRET environment variable is not set" + client_secret = os.getenv("REDDIT_API_CLIENT_SECRET") + assert client_secret is not None, "REDDIT_API_CLIENT_SECRET environment variable is not set" self.tool = Reddit( - client_id=self.client_id, - client_secret=self.client_secret, + client_id=client_id, + client_secret=client_secret, user_agent="upo-appAI", ) - def get_top_crypto_posts(self, limit:int = 5) -> list[SocialPost]: + def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]: subreddit = self.tool.subreddit("CryptoCurrency") top_posts = subreddit.top(limit=limit, time_filter="week") return [create_social_post(post) for post in top_posts] diff --git a/src/app/utils/wrapper_handler.py b/src/app/utils/wrapper_handler.py index ecc0e11..7d16c6c 100644 --- a/src/app/utils/wrapper_handler.py +++ b/src/app/utils/wrapper_handler.py @@ -24,6 +24,8 @@ class WrapperHandler(Generic[W]): try_per_wrapper (int): Number of retries per wrapper before switching to the next. retry_delay (int): Delay in seconds between retries. """ + assert not WrapperHandler.__check(wrappers), "All wrappers must be instances of their respective classes. Use `build_wrappers` to create the WrapperHandler." + self.wrappers = wrappers self.retry_per_wrapper = try_per_wrapper self.retry_delay = retry_delay @@ -87,6 +89,10 @@ class WrapperHandler(Generic[W]): raise Exception("All wrappers failed") return results + @staticmethod + def __check(wrappers: list[W]) -> bool: + return all(w.__class__ is type for w in wrappers) + @staticmethod def build_wrappers(constructors: Iterable[Type[W]], try_per_wrapper: int = 3, retry_delay: int = 2) -> 'WrapperHandler[W]': """ @@ -102,6 +108,8 @@ class WrapperHandler(Generic[W]): Raises: Exception: If no wrappers could be initialized. """ + assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}" + result = [] for wrapper_class in constructors: try: diff --git a/tests/utils/test_wrapper_handler.py b/tests/utils/test_wrapper_handler.py index d95d928..154d3dc 100644 --- a/tests/utils/test_wrapper_handler.py +++ b/tests/utils/test_wrapper_handler.py @@ -25,6 +25,29 @@ class FailingWrapperWithParameters(MockWrapperWithParameters): @pytest.mark.wrapper class TestWrapperHandler: + def test_init_failing(self): + with pytest.raises(AssertionError) as exc_info: + WrapperHandler([MockWrapper, MockWrapper2]) + assert exc_info.type == AssertionError + + def test_init_failing_empty(self): + with pytest.raises(AssertionError) as exc_info: + WrapperHandler.build_wrappers([]) + assert exc_info.type == AssertionError + + def test_init_failing_with_instances(self): + with pytest.raises(AssertionError) as exc_info: + WrapperHandler.build_wrappers([MockWrapper(), MockWrapper2()]) + assert exc_info.type == AssertionError + + def test_init_not_failing(self): + handler = WrapperHandler.build_wrappers([MockWrapper, MockWrapper2]) + assert handler is not None + assert len(handler.wrappers) == 2 + handler = WrapperHandler([MockWrapper(), MockWrapper2()]) + assert handler is not None + assert len(handler.wrappers) == 2 + def test_all_wrappers_fail(self): wrappers = [FailingWrapper, FailingWrapper] handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0)