Refactor WrapperHandler

- validation checks for initialization logic
- fix SocialAPIsTool
- fix RedditWrapper
This commit is contained in:
2025-10-01 11:05:44 +02:00
parent e4e7023c17
commit 73dcbbe12b
5 changed files with 41 additions and 9 deletions

View File

@@ -26,7 +26,7 @@ class SocialAPIsTool(SocialWrapper, Toolkit):
"""
wrappers = [RedditWrapper]
self.wrapper_handler: WrapperHandler[SocialWrapper] = WrapperHandler(wrappers)
self.wrapper_handler: WrapperHandler[SocialWrapper] = WrapperHandler.build_wrappers(wrappers)
Toolkit.__init__(
self,
@@ -38,7 +38,7 @@ class SocialAPIsTool(SocialWrapper, Toolkit):
# TODO Pensare se ha senso restituire i post da TUTTI i wrapper o solo dal primo che funziona
# la modifica è banale, basta usare try_call_all invece di try_call
def get_top_crypto_posts(self, limit:int = 5) -> list[SocialPost]:
def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]:
return self.wrapper_handler.try_call(lambda w: w.get_top_crypto_posts(limit))

View File

@@ -17,6 +17,7 @@ class SocialWrapper:
Base class for social media API wrappers.
All social media API wrappers should inherit from this class and implement the methods.
"""
def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]:
"""
Get top cryptocurrency-related posts, optionally limited by total.

View File

@@ -35,19 +35,19 @@ class RedditWrapper(SocialWrapper):
"""
def __init__(self):
self.client_id = os.getenv("REDDIT_API_CLIENT_ID")
assert self.client_id is not None, "REDDIT_API_CLIENT_ID environment variable is not set"
client_id = os.getenv("REDDIT_API_CLIENT_ID")
assert client_id is not None, "REDDIT_API_CLIENT_ID environment variable is not set"
self.client_secret = os.getenv("REDDIT_API_CLIENT_SECRET")
assert self.client_secret is not None, "REDDIT_API_CLIENT_SECRET environment variable is not set"
client_secret = os.getenv("REDDIT_API_CLIENT_SECRET")
assert client_secret is not None, "REDDIT_API_CLIENT_SECRET environment variable is not set"
self.tool = Reddit(
client_id=self.client_id,
client_secret=self.client_secret,
client_id=client_id,
client_secret=client_secret,
user_agent="upo-appAI",
)
def get_top_crypto_posts(self, limit:int = 5) -> list[SocialPost]:
def get_top_crypto_posts(self, limit: int = 5) -> list[SocialPost]:
subreddit = self.tool.subreddit("CryptoCurrency")
top_posts = subreddit.top(limit=limit, time_filter="week")
return [create_social_post(post) for post in top_posts]

View File

@@ -24,6 +24,8 @@ class WrapperHandler(Generic[W]):
try_per_wrapper (int): Number of retries per wrapper before switching to the next.
retry_delay (int): Delay in seconds between retries.
"""
assert not WrapperHandler.__check(wrappers), "All wrappers must be instances of their respective classes. Use `build_wrappers` to create the WrapperHandler."
self.wrappers = wrappers
self.retry_per_wrapper = try_per_wrapper
self.retry_delay = retry_delay
@@ -87,6 +89,10 @@ class WrapperHandler(Generic[W]):
raise Exception("All wrappers failed")
return results
@staticmethod
def __check(wrappers: list[W]) -> bool:
return all(w.__class__ is type for w in wrappers)
@staticmethod
def build_wrappers(constructors: Iterable[Type[W]], try_per_wrapper: int = 3, retry_delay: int = 2) -> 'WrapperHandler[W]':
"""
@@ -102,6 +108,8 @@ class WrapperHandler(Generic[W]):
Raises:
Exception: If no wrappers could be initialized.
"""
assert WrapperHandler.__check(constructors), f"All constructors must be classes. Received: {constructors}"
result = []
for wrapper_class in constructors:
try:

View File

@@ -25,6 +25,29 @@ class FailingWrapperWithParameters(MockWrapperWithParameters):
@pytest.mark.wrapper
class TestWrapperHandler:
def test_init_failing(self):
with pytest.raises(AssertionError) as exc_info:
WrapperHandler([MockWrapper, MockWrapper2])
assert exc_info.type == AssertionError
def test_init_failing_empty(self):
with pytest.raises(AssertionError) as exc_info:
WrapperHandler.build_wrappers([])
assert exc_info.type == AssertionError
def test_init_failing_with_instances(self):
with pytest.raises(AssertionError) as exc_info:
WrapperHandler.build_wrappers([MockWrapper(), MockWrapper2()])
assert exc_info.type == AssertionError
def test_init_not_failing(self):
handler = WrapperHandler.build_wrappers([MockWrapper, MockWrapper2])
assert handler is not None
assert len(handler.wrappers) == 2
handler = WrapperHandler([MockWrapper(), MockWrapper2()])
assert handler is not None
assert len(handler.wrappers) == 2
def test_all_wrappers_fail(self):
wrappers = [FailingWrapper, FailingWrapper]
handler: WrapperHandler[MockWrapper] = WrapperHandler.build_wrappers(wrappers, try_per_wrapper=2, retry_delay=0)