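"""Interactive news classifier.

Loads a fine-tuned DeBERTa-v3 sequence-classification checkpoint, a spaCy NER
pipeline, and the Gemini API, and exposes a simple CLI loop that prints the ML
prediction plus a structured Gemini analysis for each pasted article.

Assumed dependencies (inferred from the imports, not pinned here): torch,
transformers, spacy, google-generativeai, python-dotenv, and the spaCy model
installed via `python -m spacy download en_core_web_sm`.
"""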
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DebertaV2Tokenizer
import spacy
import google.generativeai as genai
import json
import os
import dotenv

dotenv.load_dotenv()

# Load spaCy for NER
nlp = spacy.load("en_core_web_sm")

# Load the trained ML model
model_path = "./results/checkpoint-753"  # Replace with the actual path to your model

# Alternative tokenizer loaders, kept for reference:
# tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
# tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-small')
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
def setup_gemini():
    """Configure the Gemini API key and return a GenerativeModel instance."""
    genai.configure(api_key=os.getenv("GEMINI_API"))
    gemini_model = genai.GenerativeModel('gemini-pro')
    return gemini_model
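
# Note: setup_gemini() assumes a .env file (loaded above via dotenv.load_dotenv())
# that defines GEMINI_API, for example:
#   GEMINI_API=your-google-ai-studio-key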
def predict_with_model(text):
    """Predict whether the news is real or fake using the ML model."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1).item()
    # Label index 1 is mapped to FAKE, index 0 to REAL.
    return "FAKE" if predicted_label == 1 else "REAL"
def extract_entities(text):
    """Extract named entities from text using spaCy."""
    doc = nlp(text)
    entities = [(ent.text, ent.label_) for ent in doc.ents]
    return entities
def predict_news(text):
    """Run the ML model on the text and return its 'REAL'/'FAKE' label."""
    # Predict with the ML model
    prediction = predict_with_model(text)
    return prediction
def analyze_content_gemini(model, text):
    """Ask Gemini for a structured JSON analysis of the news text."""
    prompt = f"""Analyze this news text and return a JSON object with the following structure:
    {{
        "gemini_analysis": {{
            "predicted_classification": "Real or Fake",
            "confidence_score": "0-100",
            "reasoning": ["point1", "point2"]
        }},
        "text_classification": {{
            "category": "",
            "writing_style": "Formal/Informal/Clickbait",
            "target_audience": "",
            "content_type": "news/opinion/editorial"
        }},
        "sentiment_analysis": {{
            "primary_emotion": "",
            "emotional_intensity": "1-10",
            "sensationalism_level": "High/Medium/Low",
            "bias_indicators": ["bias1", "bias2"],
            "tone": {{"formality": "formal/informal", "style": "Professional/Emotional/Neutral"}},
            "emotional_triggers": ["trigger1", "trigger2"]
        }},
        "entity_recognition": {{
            "source_credibility": "High/Medium/Low",
            "people": ["person1", "person2"],
            "organizations": ["org1", "org2"],
            "locations": ["location1", "location2"],
            "dates": ["date1", "date2"],
            "statistics": ["stat1", "stat2"]
        }},
        "context": {{
            "main_narrative": "",
            "supporting_elements": ["element1", "element2"],
            "key_claims": ["claim1", "claim2"],
            "narrative_structure": ""
        }},
        "fact_checking": {{
            "verifiable_claims": ["claim1", "claim2"],
            "evidence_present": "Yes/No",
            "fact_check_score": "0-100"
        }}
    }}

    Analyze this text and return only the JSON response: {text}"""

    response = model.generate_content(prompt)
    try:
        cleaned_text = response.text.strip()
        # Strip a ```json ... ``` fence if Gemini wrapped the response in one.
        if cleaned_text.startswith('```json'):
            cleaned_text = cleaned_text[7:-3]
        return json.loads(cleaned_text)
    except json.JSONDecodeError:
        # Fall back to a minimal result if the output is not valid JSON.
        return {
            "gemini_analysis": {
                "predicted_classification": "UNCERTAIN",
                "confidence_score": "50",
                "reasoning": ["Analysis failed to generate valid JSON"]
            }
        }
def clean_gemini_output(text):
    """Remove markdown formatting from Gemini output."""
    text = text.replace('##', '')
    text = text.replace('**', '')
    return text
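
# Note: clean_gemini_output() is not called anywhere in this pipeline; it is kept
# as a helper for stripping markdown markers if Gemini ever returns prose instead
# of the requested JSON.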
def get_gemini_analysis(text):
    """Get detailed content analysis from Gemini."""
    gemini_model = setup_gemini()
    gemini_analysis = analyze_content_gemini(gemini_model, text)
    return gemini_analysis
def main():
    print("Welcome to the News Classifier!")
    print("Enter your news text below. Type 'Exit' to quit.")
    while True:
        news_text = input("\nEnter news text: ")
        if news_text.lower() == 'exit':
            print("Thank you for using the News Classifier!")
            return

        # Get ML prediction
        prediction = predict_news(news_text)
        print(f"\nML Analysis: {prediction}")

        # Get Gemini analysis
        print("\n=== Detailed Gemini Analysis ===")
        gemini_result = get_gemini_analysis(news_text)
        # Pretty-print the JSON analysis for readability.
        print(json.dumps(gemini_result, indent=2))


if __name__ == "__main__":
    main()