"""Reddit data collection service using PRAW."""

from __future__ import annotations

import time
from datetime import datetime, timezone
from typing import Dict, Iterable, List

import praw
from praw.models import Comment, Submission

from .utils import (
    NormalizedItem,
    ServiceError,
    ServiceWarning,
    ensure_timezone,
    sanitize_text,
)

TIME_FILTER_MAP = {
    "24h": "day",
    "7d": "week",
    "30d": "month",
}


def _iter_submissions(subreddit: praw.models.Subreddit, query: str, limit: int, time_filter: str) -> Iterable[Submission]:
    """Search the subreddit for the query, oversampling (3x) so downstream filters can still fill the limit."""
    return subreddit.search(query=query, sort="new", time_filter=time_filter, limit=limit * 3)


def _iter_comments(submission: Submission) -> Iterable[Comment]:
    """Flatten a submission's comment tree into a list of comments."""
    # limit=0 drops unresolved "MoreComments" placeholders instead of fetching them.
    submission.comments.replace_more(limit=0)
    return submission.comments.list()


def _normalize_submission(submission: Submission) -> NormalizedItem:
    """Map a PRAW submission onto the shared NormalizedItem schema."""
    created = datetime.fromtimestamp(submission.created_utc, tz=timezone.utc)
    return NormalizedItem(
        source="reddit",
        id=submission.id,
        url=f"https://www.reddit.com{submission.permalink}",
        author=str(submission.author) if submission.author else None,
        timestamp=ensure_timezone(created),
        text=f"{submission.title}\n\n{submission.selftext or ''}",
        meta={
            "score": submission.score,
            "num_comments": submission.num_comments,
            "subreddit": submission.subreddit.display_name,
            "type": "submission",
        },
    )


def _normalize_comment(comment: Comment, submission: Submission) -> NormalizedItem:
    """Map a PRAW comment onto the shared NormalizedItem schema."""
    created = datetime.fromtimestamp(comment.created_utc, tz=timezone.utc)
    return NormalizedItem(
        source="reddit",
        id=comment.id,
        url=f"https://www.reddit.com{comment.permalink}",
        author=str(comment.author) if comment.author else None,
        timestamp=ensure_timezone(created),
        text=comment.body,
        meta={
            "score": comment.score,
            "subreddit": submission.subreddit.display_name,
            "type": "comment",
            "submission_title": submission.title,
        },
    )


def fetch_mentions(
    brand: str,
    credentials: Dict[str, str],
    limit: int = 25,
    date_filter: str = "7d",
    min_upvotes: int = 0,
) -> List[NormalizedItem]:
    """Fetch recent Reddit submissions/comments mentioning the brand.

    Raises ServiceWarning when credentials are missing and ServiceError when
    the client cannot be initialized or the Reddit API call fails.
    """
    client_id = credentials.get("client_id")
    client_secret = credentials.get("client_secret")
    user_agent = credentials.get("user_agent")

    if not all([client_id, client_secret, user_agent]):
        raise ServiceWarning("Reddit credentials are missing. Provide them in the sidebar to enable this source.")

    try:
        reddit = praw.Reddit(
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent,
        )
        reddit.read_only = True
    except Exception as exc:  # noqa: BLE001
        raise ServiceError(f"Failed to initialize Reddit client: {exc}") from exc

    time_filter = TIME_FILTER_MAP.get(date_filter.lower(), "week")
    subreddit = reddit.subreddit("all")
    results: List[NormalizedItem] = []
    seen_ids: set[str] = set()

    try:
        for submission in _iter_submissions(subreddit, query=brand, limit=limit, time_filter=time_filter):
            if submission.id in seen_ids:
                continue
            if submission.score < min_upvotes:
                continue
            normalized_submission = _normalize_submission(submission)
            normalized_submission["text"] = sanitize_text(normalized_submission["text"])
            if normalized_submission["text"]:
                results.append(normalized_submission)
                seen_ids.add(submission.id)
                if len(results) >= limit:
                    break

            # Also collect comments on this submission that mention the brand.
            match_count = 0
            for comment in _iter_comments(submission):
                if brand.lower() not in (comment.body or "").lower():
                    continue
                if comment.score < min_upvotes:
                    continue
                normalized_comment = _normalize_comment(comment, submission)
                normalized_comment["text"] = sanitize_text(normalized_comment["text"])
                if not normalized_comment["text"]:
                    continue
                if normalized_comment["id"] in seen_ids:
                    continue
                results.append(normalized_comment)
                seen_ids.add(normalized_comment["id"])
                match_count += 1
                if len(results) >= limit:
                    break
            if len(results) >= limit:
                break
            # Pause briefly after submissions that produced matches to respect rate limits.
            if match_count:
                time.sleep(1)
    except Exception as exc:  # noqa: BLE001
        raise ServiceError(f"Error while fetching Reddit data: {exc}") from exc

    return results
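

if __name__ == "__main__":
    # Minimal usage sketch, not part of the service itself. Run it as a module
    # (e.g. `python -m <your_package>.<this_module>`) so the relative `.utils`
    # import resolves. The credential values and the brand name "Acme" below
    # are placeholders, not real values.
    demo_credentials = {
        "client_id": "<your-client-id>",
        "client_secret": "<your-client-secret>",
        "user_agent": "brand-monitor/0.1 by u/<your-username>",
    }
    try:
        for item in fetch_mentions("Acme", demo_credentials, limit=5, date_filter="24h"):
            print(item["timestamp"], item["url"], item["meta"]["type"])
    except (ServiceWarning, ServiceError) as exc:
        print(f"Reddit source unavailable: {exc}")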