import logging
from typing import Dict, Any, List, Optional

from transformers import pipeline, AutoTokenizer
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize

logger = logging.getLogger(__name__)
class HeadlineAnalyzer:
    def __init__(self, use_ai: bool = True, model_registry: Optional[Any] = None):
        """
        Initialize the analyzers for headline analysis.

        Args:
            use_ai: Boolean indicating whether to use AI-powered analysis (True) or traditional analysis (False)
            model_registry: Optional shared model registry for better performance
        """
        self.use_ai = use_ai
        self.llm_available = False
        self.model_registry = model_registry
        if use_ai:
            try:
                if model_registry and model_registry.is_available:
                    # Use shared models
                    self.nli_pipeline = model_registry.nli
                    self.zero_shot = model_registry.zero_shot
                    self.tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
                    self.llm_available = True
                    logger.info("Using shared model pipelines for headline analysis")
                else:
                    # Initialize own pipelines
                    self.nli_pipeline = pipeline(
                        "text-classification",
                        model="roberta-large-mnli",
                        batch_size=16
                    )
                    self.zero_shot = pipeline(
                        "zero-shot-classification",
                        model="facebook/bart-large-mnli",
                        device=-1,
                        batch_size=8
                    )
                    self.tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
                    self.llm_available = True
                    logger.info("Initialized dedicated model pipelines for headline analysis")

                self.max_length = 512
            except Exception as e:
                logger.warning(f"Failed to initialize LLM pipelines: {str(e)}")
                self.llm_available = False
        else:
            logger.info("Initializing headline analyzer in traditional mode")
    def _split_content(self, headline: str, content: str) -> List[str]:
        """Split content into sections that fit within the model's token limit."""
        # Account for the headline and [SEP] token in the max length
        headline_tokens = len(self.tokenizer.encode(headline))
        sep_tokens = len(self.tokenizer.encode("[SEP]")) - 2
        max_content_tokens = self.max_length - headline_tokens - sep_tokens

        # Accumulate words until the token budget is reached, then start a new
        # section (a fixed character chunk could exceed the 512-token limit)
        sections = []
        current_words = []
        current_tokens = 0
        for word in content.split():
            word_tokens = len(self.tokenizer.encode(word, add_special_tokens=False))
            if current_tokens + word_tokens > max_content_tokens and current_words:
                sections.append(" ".join(current_words))
                current_words = []
                current_tokens = 0
            current_words.append(word)
            current_tokens += word_tokens
        if current_words:
            sections.append(" ".join(current_words))

        return sections
    def _analyze_section(self, headline: str, section: str) -> Dict[str, Any]:
        """Analyze a single section for headline accuracy and sensationalism."""
        try:
            logger.info("\n" + "-" * 30)
            logger.info("ANALYZING SECTION")
            logger.info("-" * 30)
            logger.info(f"Headline: {headline}")
            logger.info(f"Section length: {len(section)} characters")

            # Download NLTK data if needed
            try:
                nltk.data.find('tokenizers/punkt')
            except LookupError:
                nltk.download('punkt')

            sentences = sent_tokenize(section)
            logger.info(f"Found {len(sentences)} sentences in section")

            if not sentences:
                logger.warning("No sentences found in section")
                return {
                    "accuracy_score": 50.0,
                    "flagged_phrases": [],
                    "detailed_scores": {
                        "nli": {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0},
                        "sensationalism": {"factual reporting": 0.5, "accurate headline": 0.5}
                    }
                }
            # Categories for sensationalism check
            sensationalism_categories = [
                "clickbait",
                "sensationalized",
                "misleading",
                "factual reporting",
                "accurate headline"
            ]

            logger.info("Checking headline for sensationalism...")
            sensationalism_result = self.zero_shot(
                headline,
                sensationalism_categories,
                multi_label=True
            )
            sensationalism_scores = {
                label: score
                for label, score in zip(sensationalism_result['labels'], sensationalism_result['scores'])
            }
            logger.info(f"Sensationalism scores: {sensationalism_scores}")

            # Filter relevant sentences (longer than 20 chars)
            relevant_sentences = [s.strip() for s in sentences if len(s.strip()) > 20]
            logger.info(f"Found {len(relevant_sentences)} relevant sentences after filtering")

            if not relevant_sentences:
                logger.warning("No relevant sentences found in section")
                return {
                    "accuracy_score": 50.0,
                    "flagged_phrases": [],
                    "detailed_scores": {
                        "nli": {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0},
                        "sensationalism": sensationalism_scores
                    }
                }
            # Flag a sensationalized headline once, up front: the zero-shot scores
            # are headline-level, so re-attaching them inside the sentence loop
            # would only produce one duplicate flag per sentence
            flagged_phrases = []
            if (sensationalism_scores.get('sensationalized', 0) > 0.6
                    or sensationalism_scores.get('clickbait', 0) > 0.6):
                logger.info(f"Found sensationalized headline: {headline}")
                flagged_phrases.append({
                    'text': headline,
                    'type': 'Sensationalized',
                    'score': max(
                        sensationalism_scores.get('sensationalized', 0),
                        sensationalism_scores.get('clickbait', 0)
                    ),
                    'highlight': f"[SENSATIONALIZED] \"{headline}\""
                })

            # Process sentences in batches for contradiction/support
            nli_scores = []
            batch_size = 8

            logger.info("Processing sentences for contradictions...")
            for i in range(0, len(relevant_sentences), batch_size):
                batch = relevant_sentences[i:i + batch_size]
                batch_inputs = [f"{headline} [SEP] {sentence}" for sentence in batch]
                try:
                    # Get NLI scores for the batch
                    batch_results = self.nli_pipeline(batch_inputs, top_k=None)
                    if not isinstance(batch_results, list):
                        batch_results = [batch_results]

                    for sentence, result in zip(batch, batch_results):
                        scores = {item['label']: item['score'] for item in result}
                        nli_scores.append(scores)

                        # Flag contradictory content (0.3 threshold to catch borderline cases)
                        if scores.get('CONTRADICTION', 0) > 0.3:
                            logger.info(f"Found contradictory sentence (score: {scores['CONTRADICTION']:.2f}): {sentence}")
                            flagged_phrases.append({
                                'text': sentence,
                                'type': 'Contradiction',
                                'score': scores['CONTRADICTION'],
                                'highlight': f"[CONTRADICTION] (Score: {round(scores['CONTRADICTION'] * 100, 1)}%) \"{sentence}\""
                            })
                except Exception as batch_error:
                    logger.warning(f"Batch processing error: {str(batch_error)}")
                    continue
            # Calculate aggregate scores with validation
            if not nli_scores:
                logger.warning("No NLI scores available")
                avg_scores = {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0}
            else:
                try:
                    avg_scores = {
                        label: float(np.mean([
                            score.get(label, 0.0)
                            for score in nli_scores
                        ]))
                        for label in ['ENTAILMENT', 'CONTRADICTION', 'NEUTRAL']
                    }
                    logger.info(f"Average NLI scores: {avg_scores}")
                except Exception as agg_error:
                    logger.error(f"Error aggregating NLI scores: {str(agg_error)}")
                    avg_scores = {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0}
            # Calculate headline accuracy score with validation
            try:
                accuracy_components = {
                    'entailment': avg_scores.get('ENTAILMENT', 0.0) * 0.4,
                    'non_contradiction': (1 - avg_scores.get('CONTRADICTION', 0.0)) * 0.3,
                    'non_sensational': (
                        sensationalism_scores.get('factual reporting', 0.0) +
                        sensationalism_scores.get('accurate headline', 0.0)
                    ) * 0.15,
                    'non_clickbait': (
                        1 - sensationalism_scores.get('clickbait', 0.0) -
                        sensationalism_scores.get('sensationalized', 0.0)
                    ) * 0.15
                }
                logger.info(f"Accuracy components: {accuracy_components}")

                accuracy_score = sum(accuracy_components.values()) * 100

                # Validate the final score; the multi-label components can push the
                # weighted sum slightly outside [0, 100], so clamp it as well
                if np.isnan(accuracy_score) or not np.isfinite(accuracy_score):
                    logger.warning("Invalid accuracy score calculated, using default")
                    accuracy_score = 50.0
                else:
                    accuracy_score = float(max(0.0, min(100.0, accuracy_score)))
                logger.info(f"Final accuracy score: {accuracy_score:.1f}")
            except Exception as score_error:
                logger.error(f"Error calculating accuracy score: {str(score_error)}")
                accuracy_score = 50.0
            # Sort and limit flagged phrases
            sorted_phrases = sorted(
                flagged_phrases,
                key=lambda x: x['score'],
                reverse=True
            )
            unique_phrases = []
            seen = set()
            for phrase in sorted_phrases:
                if phrase['text'] not in seen:
                    unique_phrases.append(phrase)
                    seen.add(phrase['text'])
                if len(unique_phrases) >= 5:
                    break

            logger.info(f"Final number of flagged phrases: {len(unique_phrases)}")

            return {
                "accuracy_score": accuracy_score,
                "flagged_phrases": unique_phrases,
                "detailed_scores": {
                    "nli": avg_scores,
                    "sensationalism": sensationalism_scores
                }
            }
        except Exception as e:
            logger.error(f"Section analysis failed: {str(e)}")
            return {
                "accuracy_score": 50.0,
                "flagged_phrases": [],
                "detailed_scores": {
                    "nli": {"ENTAILMENT": 0.0, "CONTRADICTION": 0.0, "NEUTRAL": 1.0},
                    "sensationalism": {}
                }
            }
    def _analyze_traditional(self, headline: str, content: str) -> Dict[str, Any]:
        """Traditional headline analysis method."""
        try:
            # Download NLTK data if needed
            try:
                nltk.data.find('tokenizers/punkt')
            except LookupError:
                nltk.download('punkt')

            # Basic metrics
            headline_words = set(headline.lower().split())
            content_words = set(content.lower().split())

            # Calculate word overlap
            overlap_words = headline_words.intersection(content_words)
            overlap_score = len(overlap_words) / len(headline_words) if headline_words else 0

            # Check for clickbait patterns
            clickbait_patterns = [
                "you won't believe",
                "shocking",
                "mind blowing",
                "amazing",
                "incredible",
                "unbelievable",
                "must see",
                "click here",
                "find out",
                "what happens next"
            ]
            clickbait_count = sum(1 for pattern in clickbait_patterns if pattern in headline.lower())
            clickbait_penalty = clickbait_count * 10  # 10-point penalty per clickbait phrase

            # Calculate final score (0-100)
            base_score = overlap_score * 100
            final_score = max(0, min(100, base_score - clickbait_penalty))

            # Find phrases worth flagging for review
            flagged_phrases = []
            sentences = sent_tokenize(content)
            for sentence in sentences:
                # Flag sentences that share several words with the headline,
                # so they can be checked against its claims
                sentence_words = set(sentence.lower().split())
                if len(headline_words.intersection(sentence_words)) > 2:
                    flagged_phrases.append(sentence.strip())

                # Flag sentences with clickbait patterns
                if any(pattern in sentence.lower() for pattern in clickbait_patterns):
                    flagged_phrases.append(sentence.strip())

            return {
                "headline_vs_content_score": round(final_score, 1),
                "flagged_phrases": list(set(flagged_phrases))[:5]  # Limit to top 5 unique phrases
            }
        except Exception as e:
            logger.error(f"Traditional analysis failed: {str(e)}")
            return {
                "headline_vs_content_score": 0,
                "flagged_phrases": []
            }
    def analyze(self, headline: str, content: str) -> Dict[str, Any]:
        """Analyze how well the headline matches the content."""
        try:
            logger.info("\n" + "=" * 50)
            logger.info("HEADLINE ANALYSIS STARTED")
            logger.info("=" * 50)

            if not headline.strip() or not content.strip():
                logger.warning("Empty headline or content provided")
                return {
                    "headline_vs_content_score": 0,
                    "flagged_phrases": []
                }

            # Use LLM analysis if available and enabled
            if self.use_ai and self.llm_available:
                logger.info("Using LLM analysis for headline")

                # Split content if needed
                sections = self._split_content(headline, content)
                section_results = []

                # Analyze each section
                for section in sections:
                    result = self._analyze_section(headline, section)
                    section_results.append(result)

                # Aggregate results across sections
                accuracy_scores = [r['accuracy_score'] for r in section_results]
                final_score = float(np.mean(accuracy_scores))

                # Combine and deduplicate flagged phrases
                all_phrases = []
                for result in section_results:
                    if 'flagged_phrases' in result:
                        all_phrases.extend(result['flagged_phrases'])

                # Sort by score and keep the top unique phrases
                sorted_phrases = sorted(all_phrases, key=lambda x: x['score'], reverse=True)
                unique_phrases = []
                seen = set()
                for phrase in sorted_phrases:
                    if phrase['text'] not in seen:
                        unique_phrases.append(phrase)
                        seen.add(phrase['text'])
                    if len(unique_phrases) >= 5:
                        break

                return {
                    "headline_vs_content_score": round(final_score, 1),
                    "flagged_phrases": unique_phrases
                }
            else:
                # Use traditional analysis
                logger.info("Using traditional headline analysis")
                return self._analyze_traditional(headline, content)
        except Exception as e:
            logger.error(f"Headline analysis failed: {str(e)}")
            return {
                "headline_vs_content_score": 0,
                "flagged_phrases": []
            }
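

# Minimal usage sketch (not part of the original module): runs the analyzer in
# traditional mode so no transformers model downloads are needed; the headline
# and content strings below are illustrative only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    analyzer = HeadlineAnalyzer(use_ai=False)
    demo_headline = "Scientists discover new species of deep-sea fish"
    demo_content = (
        "Researchers announced on Tuesday that they have identified a previously "
        "unknown species of fish. It was recorded at a depth of roughly 8,000 "
        "meters during a survey of an ocean trench. The team said further "
        "analysis is needed before the finding is formally described."
    )
    # Expected output: a dict with headline_vs_content_score and flagged_phrases
    print(analyzer.analyze(demo_headline, demo_content))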