Spaces:
Sleeping
Sleeping
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional | |
| from praw import Reddit | |
| from pydantic import Field, PrivateAttr, SecretStr | |
| from pydantic_settings import BaseSettings | |
| from obsei.payload import TextPayload | |
| from obsei.misc.utils import ( | |
| DATETIME_STRING_PATTERN, | |
| DEFAULT_LOOKUP_PERIOD, | |
| convert_utc_time, | |
| text_from_html, | |
| ) | |
| from obsei.source.base_source import BaseSource, BaseSourceConfig | |
class RedditCredInfo(BaseSettings):
    """Reddit API credentials, supplied as keyword arguments or via the environment.

    The secret fields and ``username`` are looked up under ``reddit_``-prefixed
    environment variable names (``reddit_client_id``, ``reddit_client_secret``,
    ``reddit_refresh_token``, ``reddit_username``, ``reddit_password``).
    """

    # Create credential at https://www.reddit.com/prefs/apps
    # Also refer https://praw.readthedocs.io/en/latest/getting_started/authentication.html
    # Currently Password Flow, Read Only Mode and Saved Refresh Token Mode are supported

    # The aliased fields below must stay settable by their field name
    # (e.g. ``RedditCredInfo(client_id=...)``), not only by their alias.
    model_config = {"populate_by_name": True}

    # NOTE: pydantic-settings (pydantic v2) does not honour ``Field(env=...)``;
    # the per-field environment variable name is declared with ``validation_alias``.
    client_id: SecretStr = Field(None, validation_alias="reddit_client_id")
    client_secret: SecretStr = Field(None, validation_alias="reddit_client_secret")
    user_agent: str = "Test User Agent"
    redirect_uri: Optional[str] = None
    refresh_token: Optional[SecretStr] = Field(
        None, validation_alias="reddit_refresh_token"
    )
    username: Optional[str] = Field(None, validation_alias="reddit_username")
    password: Optional[SecretStr] = Field(None, validation_alias="reddit_password")
    # Copied onto the praw client's ``read_only`` attribute by RedditConfig
    read_only: bool = True
class RedditConfig(BaseSourceConfig):
    """Configuration for the Reddit source.

    Constructing the config also constructs the ``praw.Reddit`` client from
    ``cred_info``; when no ``cred_info`` is given a default ``RedditCredInfo()``
    is created.
    """

    # This is done to avoid exposing member to API response
    _reddit_client: Reddit = PrivateAttr()
    TYPE: str = "Reddit"
    # Subreddit names to read submissions from
    subreddits: List[str]
    # Optional allow-list of submission ids; other submissions are skipped
    post_ids: Optional[List[str]] = None
    # Look-back window used when no state is stored for a submission
    lookup_period: Optional[str] = None
    # When true, the submission's data is attached to every comment's meta
    include_post_meta: Optional[bool] = True
    # Key under which the submission's data is attached
    post_meta_field: str = "post_meta"
    cred_info: Optional[RedditCredInfo] = Field(None)

    def __init__(self, **data: Any):
        """Validate the fields, then build the Reddit client from the credentials."""
        super().__init__(**data)

        self.cred_info = self.cred_info or RedditCredInfo()
        creds = self.cred_info

        # client_id / client_secret default to None on RedditCredInfo. Forward
        # None to praw (as is done for the other optional secrets) instead of
        # failing here with an AttributeError on ``None.get_secret_value()``.
        client_id = (
            creds.client_id.get_secret_value()
            if creds.client_id is not None
            else None
        )
        client_secret = (
            creds.client_secret.get_secret_value()
            if creds.client_secret is not None
            else None
        )
        refresh_token = (
            creds.refresh_token.get_secret_value() if creds.refresh_token else None
        )
        password = creds.password.get_secret_value() if creds.password else None

        self._reddit_client = Reddit(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=creds.redirect_uri,
            user_agent=creds.user_agent,
            refresh_token=refresh_token,
            # An empty username is normalised to None
            username=creds.username or None,
            password=password,
        )

        self._reddit_client.read_only = creds.read_only

    def get_reddit_client(self) -> Reddit:
        """Return the Reddit client created in ``__init__``."""
        return self._reddit_client
class RedditSource(BaseSource):
    """Source that turns top-level comments of subreddit submissions into payloads."""

    NAME: str = "Reddit"

    def lookup(self, config: RedditConfig, **kwargs: Any) -> List[TextPayload]:  # type: ignore[override]
        """Collect new top-level comments from the configured subreddits.

        Args:
            config: Reddit configuration (subreddits, optional post id filter,
                look-back period, meta options and the Reddit client).
            **kwargs: ``id`` (optional) is the workflow id used to load and
                persist per-submission state through ``self.store``.

        Returns:
            One ``TextPayload`` per collected comment; ``meta`` holds the
            comment's attribute dict (plus the submission's attribute dict
            under ``config.post_meta_field`` when ``include_post_meta`` is set).
        """
        source_responses: List[TextPayload] = []

        # Get data from state
        workflow_id: Optional[str] = kwargs.get("id", None)
        state: Optional[Dict[str, Any]] = (
            None
            if workflow_id is None or self.store is None
            else self.store.get_source_state(workflow_id)
        )
        update_state: bool = bool(workflow_id)
        state = state or dict()

        subreddit_reference = config.get_reddit_client().subreddit(
            "+".join(config.subreddits)
        )
        post_stream = subreddit_reference.stream.submissions(pause_after=-1)

        for post in post_stream:
            # A None item from the stream ends the scan
            if post is None:
                break

            post_data = vars(post)
            post_id = post_data["id"]
            if config.post_ids and post_id not in config.post_ids:
                continue

            post_stat: Dict[str, Any] = state.get(post_id, dict())
            lookup_period: str = post_stat.get("since_time", config.lookup_period)
            lookup_period = lookup_period or DEFAULT_LOOKUP_PERIOD
            # Short values are handed to convert_utc_time; longer ones are
            # parsed as a timestamp in DATETIME_STRING_PATTERN (the format
            # this method writes back into the state below).
            if len(lookup_period) <= 5:
                since_time = convert_utc_time(lookup_period)
            else:
                since_time = datetime.strptime(lookup_period, DATETIME_STRING_PATTERN)

            last_since_time: datetime = since_time

            # Id of the newest comment collected by the previous run, if any
            since_id: Optional[str] = post_stat.get("since_comment_id", None)
            # Id to store for the next run; replaced by the first comment kept
            newest_comment_id: Optional[str] = since_id
            first_comment = True

            state[post_id] = post_stat

            post.comment_sort = "new"
            post.comments.replace_more(limit=None)

            # top_level_comments only
            for comment in post.comments:
                comment_data = vars(comment)
                if config.include_post_meta:
                    comment_data[config.post_meta_field] = post_data

                comment_time = datetime.utcfromtimestamp(
                    int(comment_data["created_utc"])
                )
                comment_id = comment_data["id"]

                if comment_time < since_time:
                    break
                # Stop at the comment stored by the previous run. This must be
                # compared with the *stored* id: comparing with the id being
                # tracked for the next run (which is overwritten by the first
                # new comment) would re-emit the already-collected comment.
                if since_id and since_id == comment_id:
                    break

                if last_since_time < comment_time:
                    last_since_time = comment_time
                if first_comment:
                    newest_comment_id = comment_id
                    first_comment = False

                text = "".join(text_from_html(comment_data["body_html"]))

                source_responses.append(
                    TextPayload(
                        processed_text=text, meta=comment_data, source_name=self.NAME
                    )
                )

            post_stat["since_time"] = last_since_time.strftime(DATETIME_STRING_PATTERN)
            post_stat["since_comment_id"] = newest_comment_id

        if update_state and self.store is not None:
            self.store.update_source_state(workflow_id=workflow_id, state=state)

        return source_responses