|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import time |
|
from typing import Any, Dict, List, Union |
|
|
|
from requests.exceptions import RequestException |
|
|
|
from camel.toolkits import FunctionTool |
|
from camel.toolkits.base import BaseToolkit |
|
|
|
|
|
class RedditToolkit(BaseToolkit):
    r"""A class representing a toolkit for Reddit operations.

    This toolkit provides methods to interact with the Reddit API, allowing
    users to collect top posts, perform sentiment analysis on comments, and
    track keyword discussions across multiple subreddits.

    Attributes:
        retries (int): Number of retries for API requests in case of failure.
        delay (int): Delay between retries in seconds.
        client_id (str): Value of the `REDDIT_CLIENT_ID` environment
            variable, or an empty string when it is not set.
        client_secret (str): Value of the `REDDIT_CLIENT_SECRET` environment
            variable, or an empty string when it is not set.
        user_agent (str): Value of the `REDDIT_USER_AGENT` environment
            variable, or an empty string when it is not set.
        reddit (Reddit): An instance of the Reddit client.
    """

    def __init__(self, retries: int = 3, delay: int = 0):
        r"""Initializes the RedditToolkit with the specified number of retries
        and delay.

        Args:
            retries (int): Number of times to retry the request in case of
                failure. Defaults to `3`.
            delay (int): Time in seconds to wait between retries. Defaults to
                `0`.
        """
        # Imported here rather than at module level, so `praw` is only
        # needed once a toolkit instance is actually created.
        from praw import Reddit

        self.retries = retries
        self.delay = delay

        # Missing variables fall back to "" so that the public methods can
        # detect unset credentials and return a warning string.
        self.client_id = os.environ.get("REDDIT_CLIENT_ID", "")
        self.client_secret = os.environ.get("REDDIT_CLIENT_SECRET", "")
        self.user_agent = os.environ.get("REDDIT_USER_AGENT", "")

        self.reddit = Reddit(
            client_id=self.client_id,
            client_secret=self.client_secret,
            user_agent=self.user_agent,
            request_timeout=30,
        )

    def _retry_request(self, func, *args, **kwargs):
        r"""Retries a function in case of network-related errors.

        Args:
            func (callable): The function to be retried.
            *args: Arguments to pass to the function.
            **kwargs: Keyword arguments to pass to the function.

        Returns:
            Any: The result of the function call if successful.

        Raises:
            RequestException: If all retry attempts fail.
        """
        # Always make at least one attempt: with `retries <= 0` a plain
        # `range(self.retries)` loop would never call `func` and the method
        # would silently return None.
        attempts = max(1, self.retries)

        # NOTE(review): only `requests.exceptions.RequestException` is
        # retried; whether the Reddit client surfaces network failures as
        # this exact type should be confirmed against its documentation.
        for attempt in range(attempts):
            try:
                return func(*args, **kwargs)
            except RequestException as e:
                print(f"Attempt {attempt + 1}/{attempts} failed: {e}")
                if attempt == attempts - 1:
                    # Out of attempts: propagate the last error unchanged.
                    raise
                time.sleep(self.delay)

    def _fetch_comments(self, post: Any, comment_limit: int) -> List[Any]:
        r"""Fetches the comments of a post, skipping entries without a body.

        Args:
            post (Any): The post object whose `comments` are read.
            comment_limit (int): The maximum number of comments to return.

        Returns:
            List[Any]: At most `comment_limit` comment objects, each of which
                has a `body` attribute.
        """
        fetched = self._retry_request(lambda: list(post.comments))
        # A comment listing may contain placeholder entries (PRAW's
        # `MoreComments`) that have no `body`; reading `.body` on them
        # raises AttributeError, so they are dropped before truncating.
        with_body = [entry for entry in fetched if hasattr(entry, "body")]
        return with_body[:comment_limit]

    def collect_top_posts(
        self,
        subreddit_name: str,
        post_limit: int = 5,
        comment_limit: int = 5,
    ) -> Union[List[Dict[str, Any]], str]:
        r"""Collects the top posts and their comments from a specified
        subreddit.

        Args:
            subreddit_name (str): The name of the subreddit to collect posts
                from.
            post_limit (int): The maximum number of top posts to collect.
                Defaults to `5`.
            comment_limit (int): The maximum number of top comments to collect
                per post. Defaults to `5`.

        Returns:
            Union[List[Dict[str, Any]], str]: A list of dictionaries, each
                containing the post title and its top comments if success.
                String warning if credentials are not set.
        """
        if not all([self.client_id, self.client_secret, self.user_agent]):
            return (
                "Reddit API credentials are not set. "
                "Please set the environment variables."
            )

        subreddit = self._retry_request(self.reddit.subreddit, subreddit_name)
        top_posts = self._retry_request(subreddit.top, limit=post_limit)
        data = []

        for post in top_posts:
            title = post.title
            comments = self._fetch_comments(post, comment_limit)
            data.append(
                {
                    "Post Title": title,
                    "Comments": [
                        {
                            "Comment Body": comment.body,
                            "Upvotes": comment.score,
                        }
                        for comment in comments
                    ],
                }
            )
            # Pause between posts using the configured delay.
            time.sleep(self.delay)

        return data

    def perform_sentiment_analysis(
        self, data: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Performs sentiment analysis on the comments collected from Reddit
        posts.

        Both record shapes produced by this toolkit are accepted: flat
        comment records holding a 'Comment Body' key (as returned by
        `track_keyword_discussions`) and post records holding a 'Comments'
        list of such comment records (as returned by `collect_top_posts`).

        Args:
            data (List[Dict[str, Any]]): A list of dictionaries containing
                Reddit post data and comments.

        Returns:
            List[Dict[str, Any]]: The original data with an added 'Sentiment
                Score' for each comment. The input dictionaries are modified
                in place.

        Raises:
            KeyError: If a record has neither a 'Comments' list nor a
                'Comment Body' key.
        """
        from textblob import TextBlob

        def polarity(record: Dict[str, Any]) -> Any:
            return TextBlob(record["Comment Body"]).sentiment.polarity

        for item in data:
            nested = item.get("Comments")
            has_nested = isinstance(nested, list)

            if has_nested:
                # Post-shaped record: score every nested comment.
                for comment in nested:
                    comment["Sentiment Score"] = polarity(comment)

            if "Comment Body" in item or not has_nested:
                # Flat comment record. A record with neither key still
                # raises KeyError here instead of being skipped silently.
                item["Sentiment Score"] = polarity(item)

        return data

    def track_keyword_discussions(
        self,
        subreddits: List[str],
        keywords: List[str],
        post_limit: int = 10,
        comment_limit: int = 10,
        sentiment_analysis: bool = False,
    ) -> Union[List[Dict[str, Any]], str]:
        r"""Tracks discussions about specific keywords in specified subreddits.

        Args:
            subreddits (List[str]): A list of subreddit names to search within.
            keywords (List[str]): A list of keywords to track in the subreddit
                discussions.
            post_limit (int): The maximum number of top posts to collect per
                subreddit. Defaults to `10`.
            comment_limit (int): The maximum number of top comments to collect
                per post. Defaults to `10`.
            sentiment_analysis (bool): If True, performs sentiment analysis on
                the comments. Defaults to `False`.

        Returns:
            Union[List[Dict[str, Any]], str]: A list of dictionaries
                containing the subreddit name, post title, comment body, and
                upvotes for each comment that contains the specified keywords
                if success. String warning if credentials are not set.
        """
        if not all([self.client_id, self.client_secret, self.user_agent]):
            return (
                "Reddit API credentials are not set. "
                "Please set the environment variables."
            )

        # Lower-case the keywords once instead of once per comment; this
        # also makes a one-shot iterable of keywords safe to reuse.
        lowered_keywords = [keyword.lower() for keyword in keywords]
        data = []

        for subreddit_name in subreddits:
            subreddit = self._retry_request(
                self.reddit.subreddit, subreddit_name
            )
            top_posts = self._retry_request(subreddit.top, limit=post_limit)

            for post in top_posts:
                for comment in self._fetch_comments(post, comment_limit):
                    # Case-insensitive substring match on the comment text.
                    body_lower = comment.body.lower()
                    if any(
                        keyword in body_lower for keyword in lowered_keywords
                    ):
                        data.append(
                            {
                                "Subreddit": subreddit_name,
                                "Post Title": post.title,
                                "Comment Body": comment.body,
                                "Upvotes": comment.score,
                            }
                        )

                # Pause between posts using the configured delay.
                time.sleep(self.delay)

        if sentiment_analysis:
            data = self.perform_sentiment_analysis(data)
        return data

    def get_tools(self) -> List[FunctionTool]:
        r"""Returns a list of FunctionTool objects representing the
        functions in the toolkit.

        Returns:
            List[FunctionTool]: A list of FunctionTool objects for the
                toolkit methods.
        """
        return [
            FunctionTool(self.collect_top_posts),
            FunctionTool(self.perform_sentiment_analysis),
            FunctionTool(self.track_keyword_discussions),
        ]
|
|