Spaces:
Running
Running
import streamlit as st | |
import random | |
# Page configuration.
# NOTE(review): this call sits ahead of the remaining imports, presumably
# because Streamlit expects set_page_config to run before any other Streamlit
# command (and `final` may issue some on import) - confirm before reordering.
st.set_page_config(page_title="Nexus NLP News Classifier")
import pandas as pd | |
from final import * | |
from pydantic import BaseModel | |
import plotly.graph_objects as go | |
def initialize_models():
    """Load every model and resource the UI needs.

    The spaCy pipeline ``en_core_web_sm`` is loaded first; if spaCy reports
    that the package is missing, it is downloaded and loaded again. The
    tokenizer comes from the ``microsoft/deberta-v3-small`` hub checkpoint,
    while the classifier weights come from the local fine-tuning checkpoint
    directory.

    Returns:
        tuple: ``(nlp, tokenizer, model, knowledge_graph)`` - the spaCy
        pipeline, the DeBERTa tokenizer, the sequence-classification model
        (already switched to eval mode) and whatever
        ``load_knowledge_graph()`` returns.

    Raises:
        Any exception raised by the download or by the model/tokenizer
        loaders propagates to the caller unchanged.
    """
    # NOTE(review): main() calls this on every run of the script; consider
    # caching the result if reloading the models proves slow.
    spacy_model_name = "en_core_web_sm"
    try:
        nlp = spacy.load(spacy_model_name)
    except OSError:
        # spaCy raises OSError when the model package is not installed.
        # Only that case should trigger a download; anything else (including
        # KeyboardInterrupt, which a bare `except:` would swallow) propagates.
        spacy.cli.download(spacy_model_name)
        nlp = spacy.load(spacy_model_name)

    model_path = "./results/checkpoint-753"
    tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-small')
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()  # inference only

    knowledge_graph = load_knowledge_graph()
    return nlp, tokenizer, model, knowledge_graph
class NewsInput(BaseModel):
    """Pydantic model with a single required string field, ``text``.

    Not referenced anywhere else in the visible part of this file.
    """

    text: str
# def generate_knowledge_graph_viz(text, nlp, tokenizer, model): | |
# kg_builder = KnowledgeGraphBuilder() | |
# # Get prediction | |
# prediction, _ = predict_with_model(text, tokenizer, model) | |
# is_fake = prediction == "FAKE" | |
# # Update knowledge graph | |
# kg_builder.update_knowledge_graph(text, not is_fake, nlp) | |
# # Randomly select subset of edges (e.g. 10% of edges) | |
# edges = list(kg_builder.knowledge_graph.edges()) | |
# selected_edges = random.sample(edges, k=int(len(edges) * 0.3)) | |
# # Create a new graph with selected edges | |
# selected_graph = nx.DiGraph() | |
# selected_graph.add_nodes_from(kg_builder.knowledge_graph.nodes(data=True)) | |
# selected_graph.add_edges_from(selected_edges) | |
# pos = nx.spring_layout(selected_graph) | |
# edge_trace = go.Scatter( | |
# x=[], y=[], | |
# line=dict( | |
# width=2, | |
# color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)' | |
# ), | |
# hoverinfo='none', | |
# mode='lines' | |
# ) | |
# # Create visualization | |
# pos = nx.spring_layout(kg_builder.knowledge_graph) | |
# edge_trace = go.Scatter( | |
# x=[], y=[], | |
# line=dict( | |
# width=2, | |
# color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)' | |
# ), | |
# hoverinfo='none', | |
# mode='lines' | |
# ) | |
# node_trace = go.Scatter( | |
# x=[], y=[], | |
# mode='markers+text', | |
# hoverinfo='text', | |
# textposition='top center', | |
# marker=dict( | |
# size=15, | |
# color='white', | |
# line=dict(width=2, color='black') | |
# ), | |
# text=[] | |
# ) | |
# # Add edges | |
# for edge in selected_graph.edges(): | |
# x0, y0 = pos[edge[0]] | |
# x1, y1 = pos[edge[1]] | |
# edge_trace['x'] += (x0, x1, None) | |
# edge_trace['y'] += (y0, y1, None) | |
# # Add nodes | |
# for node in kg_builder.knowledge_graph.nodes(): | |
# x, y = pos[node] | |
# node_trace['x'] += (x,) | |
# node_trace['y'] += (y,) | |
# node_trace['text'] += (node,) | |
# fig = go.Figure( | |
# data=[edge_trace, node_trace], | |
# layout=go.Layout( | |
# showlegend=False, | |
# hovermode='closest', | |
# margin=dict(b=0,l=0,r=0,t=0), | |
# xaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | |
# yaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | |
# plot_bgcolor='rgba(0,0,0,0)', | |
# paper_bgcolor='rgba(0,0,0,0)' | |
# ) | |
# ) | |
# return fig | |
def generate_knowledge_graph_viz(text, nlp, tokenizer, model):
    """Build a two-colour Plotly figure of the knowledge graph for ``text``.

    NOTE: a function with the same name is defined again further down in this
    file; that later definition is the one bound to the name once the module
    has finished loading, so this variant is effectively unused.

    Steps:
      1. Classify ``text`` with ``predict_with_model`` and feed it into a
         fresh ``KnowledgeGraphBuilder``.
      2. Randomly keep 50% of the graph's edges for display.
      3. Randomly pick 15% of the kept edges to draw in the colour opposite
         to the prediction (the choice is random, not derived from edge data).
      4. Lay the graph out with ``nx.spring_layout`` and draw edges + nodes.

    Args:
        text: News text to analyse.
        nlp: spaCy pipeline handed to the knowledge-graph builder.
        tokenizer: Tokenizer handed to ``predict_with_model``.
        model: Classifier handed to ``predict_with_model``.

    Returns:
        plotly.graph_objects.Figure: dominant-colour edges, opposite-colour
        edges and labelled nodes, on a transparent background without axes.
    """
    builder = KnowledgeGraphBuilder()

    # Classify the text, then let the builder ingest it.
    label, _ = predict_with_model(text, tokenizer, model)
    is_fake = label == "FAKE"
    builder.update_knowledge_graph(text, not is_fake, nlp)
    source_graph = builder.knowledge_graph

    # Random 50% sample of the edges, then a random 15% of that sample.
    every_edge = list(source_graph.edges())
    edge_total = len(every_edge)
    shown_edges = random.sample(
        every_edge, k=min(int(edge_total * 0.5), edge_total)
    )
    flipped_edges = set(
        random.sample(shown_edges, k=int(len(shown_edges) * 0.15))
    )

    # Layout graph: every node of the source graph, only the sampled edges.
    layout_graph = nx.DiGraph()
    layout_graph.add_nodes_from(source_graph.nodes(data=True))
    layout_graph.add_edges_from(shown_edges)
    pos = nx.spring_layout(layout_graph)

    red = 'rgba(255,0,0,0.7)'
    green = 'rgba(0,255,0,0.7)'
    if is_fake:
        dominant_color, opposite_color = red, green
    else:
        dominant_color, opposite_color = green, red

    # Line-segment coordinates, keyed by "is this edge flipped?".
    # A None entry separates consecutive segments.
    segment_x = {False: [], True: []}
    segment_y = {False: [], True: []}
    for edge in shown_edges:
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        bucket = edge in flipped_edges
        segment_x[bucket].extend((x0, x1, None))
        segment_y[bucket].extend((y0, y1, None))

    def make_edge_trace(color, flipped):
        return go.Scatter(
            x=segment_x[flipped],
            y=segment_y[flipped],
            line=dict(width=2, color=color),
            hoverinfo='none',
            mode='lines',
        )

    dominant_edge_trace = make_edge_trace(dominant_color, False)
    opposite_edge_trace = make_edge_trace(opposite_color, True)

    node_names = list(layout_graph.nodes())
    node_trace = go.Scatter(
        x=[pos[name][0] for name in node_names],
        y=[pos[name][1] for name in node_names],
        mode='markers+text',
        hoverinfo='text',
        textposition='top center',
        marker=dict(
            size=15,
            color='white',
            line=dict(width=2, color='black'),
        ),
        text=node_names,
    )

    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    return go.Figure(
        data=[dominant_edge_trace, opposite_edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
        ),
    )
def _split_display_edges(display_edges, total_edges):
    """Partition the sampled edges into (primary, opposite, orange) lists.

    The opposite and orange groups each get ``int(total_edges * 0.15)`` edges
    (clamped to what is available); the primary group gets everything that is
    left, so every sampled edge ends up in exactly one group.
    """
    shown = len(display_edges)
    opposite_count = min(int(total_edges * 0.15), shown)
    orange_count = min(int(total_edges * 0.15), shown - opposite_count)
    # Remainder goes to the primary colour. Truncating the three shares
    # independently can leave edges without any colour, i.e. never drawn.
    primary_count = shown - opposite_count - orange_count

    opposite_end = primary_count + opposite_count
    return (
        display_edges[:primary_count],
        display_edges[primary_count:opposite_end],
        display_edges[opposite_end:],
    )


def _edge_line_trace(edges, pos, color):
    """Return one Scatter trace drawing ``edges`` as line segments in ``color``."""
    xs, ys = [], []
    for edge in edges:
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        # A None entry breaks the line between consecutive segments.
        xs.extend((x0, x1, None))
        ys.extend((y0, y1, None))
    return go.Scatter(
        x=xs,
        y=ys,
        line=dict(width=2, color=color),
        hoverinfo='none',
        mode='lines',
    )


def generate_knowledge_graph_viz(text, nlp, tokenizer, model):
    """Build a three-colour Plotly figure of the knowledge graph for ``text``.

    The text is classified with ``predict_with_model`` and fed into a fresh
    ``KnowledgeGraphBuilder``. A random 60% of the resulting edges is kept for
    display and split into three colour groups:

      * primary  - red when the prediction is "FAKE", green otherwise
      * opposite - the other one of red/green (about 15% of all edges)
      * orange   - about 15% of all edges

    Group membership is decided by random position in the sample; it is not
    derived from any attribute of the edges. Every sampled edge is drawn.

    Args:
        text: News text to analyse.
        nlp: spaCy pipeline handed to the knowledge-graph builder.
        tokenizer: Tokenizer handed to ``predict_with_model``.
        model: Classifier handed to ``predict_with_model``.

    Returns:
        plotly.graph_objects.Figure: three edge traces (primary, opposite,
        orange) followed by one labelled node trace, on a transparent
        background with hidden axes.
    """
    kg_builder = KnowledgeGraphBuilder()

    # Get prediction
    prediction, _ = predict_with_model(text, tokenizer, model)
    is_fake = prediction == "FAKE"

    # Update knowledge graph
    kg_builder.update_knowledge_graph(text, not is_fake, nlp)
    graph = kg_builder.knowledge_graph

    # Keep 60% of the edges. random.sample already returns them in random
    # order, so no extra shuffle is needed before slicing into groups.
    all_edges = list(graph.edges())
    total_edges = len(all_edges)
    display_edges = random.sample(all_edges, k=int(total_edges * 0.6))
    primary_edges, opposite_edges, orange_edges = _split_display_edges(
        display_edges, total_edges
    )

    # Layout graph: every node of the source graph, only the sampled edges.
    selected_graph = nx.DiGraph()
    selected_graph.add_nodes_from(graph.nodes(data=True))
    selected_graph.add_edges_from(display_edges)
    pos = nx.spring_layout(selected_graph)

    red = 'rgba(255,0,0,0.7)'
    green = 'rgba(0,255,0,0.7)'
    orange = 'rgba(255,165,0,0.7)'
    primary_color, opposite_color = (red, green) if is_fake else (green, red)

    primary_edge_trace = _edge_line_trace(primary_edges, pos, primary_color)
    opposite_edge_trace = _edge_line_trace(opposite_edges, pos, opposite_color)
    orange_edge_trace = _edge_line_trace(orange_edges, pos, orange)

    nodes = list(selected_graph.nodes())
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        mode='markers+text',
        hoverinfo='text',
        textposition='top center',
        marker=dict(
            size=15,
            color='white',
            line=dict(width=2, color='black'),
        ),
        text=nodes,
    )

    return go.Figure(
        data=[primary_edge_trace, opposite_edge_trace, orange_edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
        ),
    )
# Streamlit UI
def _as_text_list(value, default="N/A"):
    """Normalise a Gemini result field to a non-empty list of strings.

    Missing/empty values become ``[default]``; a bare string is wrapped rather
    than iterated character by character; other iterables are stringified.
    """
    if not value:
        return [default]
    if isinstance(value, str):
        return [value]
    return [str(item) for item in value]


def _request_gemini_analysis(news_text, max_retries=10, retry_delay=1):
    """Run the Gemini analysis, retrying on exceptions or unusable results.

    A result counts as usable when it is truthy and has a truthy
    ``'gemini_analysis'`` entry. Each failure is printed; the UI shows a
    single error banner only after every attempt has been used up. There is
    no sleep after the final attempt.

    Returns:
        The first usable result; otherwise the last truthy result obtained,
        or a default "UNCERTAIN" placeholder if there was none.
    """
    import time  # `time` is not imported at module level in this file

    gemini_result = None
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            gemini_model = setup_gemini()
            gemini_result = analyze_content_gemini(gemini_model, news_text)
            # Check if we got valid results
            if gemini_result and gemini_result.get('gemini_analysis'):
                return gemini_result
        except Exception as e:  # boundary: any client/API failure means "retry"
            last_error = e
            print(f"Gemini error: {str(e)}")
        if attempt < max_retries:
            time.sleep(retry_delay)

    if last_error is not None:
        st.error(f"Gemini API error: {str(last_error)}")

    # Use default values if all retries failed
    if not gemini_result:
        gemini_result = {
            "gemini_analysis": {
                "predicted_classification": "UNCERTAIN",
                "confidence_score": "50",
                "reasoning": ["Analysis temporarily unavailable"]
            }
        }
    return gemini_result


def _render_detailed_analysis(gemini_result, news_text, nlp, tokenizer, model):
    """Render the sections inside the "Detailed Analysis" expander.

    Any exception aborts the remaining sections, shows a generic error in the
    UI and prints the underlying exception for diagnosis.
    """
    # NOTE(review): the leading 'π' in four of the headers below looks like a
    # mis-decoded emoji of which only the first byte survived; the original
    # glyph cannot be recovered from this file - restore it from upstream.
    try:
        # Analysis Reasoning
        st.subheader("π Analysis Reasoning")
        reasoning = gemini_result.get('gemini_analysis', {}).get('reasoning')
        for point in _as_text_list(reasoning):
            st.write(f"• {point}")

        # Named Entities from spaCy
        st.subheader("🏷️ Named Entities")
        spacy_entities = extract_entities(news_text, nlp)
        st.dataframe(pd.DataFrame(spacy_entities, columns=["Entity", "Type"]))

        # Knowledge Graph Visualization
        st.subheader("🕸️ Knowledge Graph")
        fig = generate_knowledge_graph_viz(news_text, nlp, tokenizer, model)
        st.plotly_chart(fig, use_container_width=True)

        # Text Classification
        st.subheader("π Text Classification")
        text_class = gemini_result.get('text_classification', {})
        st.write(f"Category: {text_class.get('category', 'N/A')}")
        st.write(f"Writing Style: {text_class.get('writing_style', 'N/A')}")
        st.write(f"Target Audience: {text_class.get('target_audience', 'N/A')}")
        st.write(f"Content Type: {text_class.get('content_type', 'N/A')}")

        # Sentiment Analysis
        st.subheader("π Sentiment Analysis")
        sentiment = gemini_result.get('sentiment_analysis', {})
        st.write(f"Primary Emotion: {sentiment.get('primary_emotion', 'N/A')}")
        st.write(f"Emotional Intensity: {sentiment.get('emotional_intensity', 'N/A')}/10")
        st.write(f"Sensationalism Level: {sentiment.get('sensationalism_level', 'N/A')}")
        st.write("Bias Indicators:", ", ".join(_as_text_list(sentiment.get('bias_indicators'))))

        # Entity Recognition
        st.subheader("π Entity Recognition")
        gemini_entities = gemini_result.get('entity_recognition', {})
        st.write(f"Source Credibility: {gemini_entities.get('source_credibility', 'N/A')}")
        st.write("People:", ", ".join(_as_text_list(gemini_entities.get('people'))))
        st.write("Organizations:", ", ".join(_as_text_list(gemini_entities.get('organizations'))))
    except Exception as e:  # UI boundary: never crash the page on bad data
        st.error("Error processing analysis results")
        print(f"Analysis rendering error: {str(e)}")


def main():
    """Streamlit entry point: collect news text, classify it, show results.

    Loads the models, shows a text area and an "Analyze" button, and on click
    runs the ML classifier, the knowledge-graph predictor and the Gemini
    analysis. The ML prediction and confidence are shown as metrics; all other
    details go into a collapsible expander.
    """
    st.title("Nexus NLP News Classifier")
    st.write("Enter news text below to analyze its authenticity")

    # Initialize models
    nlp, tokenizer, model, knowledge_graph = initialize_models()

    # Text input area
    news_text = st.text_area("News Text", height=200)

    if not st.button("Analyze"):
        return
    if not news_text or not news_text.strip():
        st.warning("Please enter some text to analyze")
        return

    with st.spinner("Analyzing..."):
        # Get predictions from all models
        ml_prediction, ml_confidence = predict_with_model(news_text, tokenizer, model)
        # The knowledge-graph prediction is computed but not displayed at the
        # moment; the call is kept in case it has side effects.
        _kg_prediction, _kg_confidence = predict_with_knowledge_graph(
            news_text, knowledge_graph, nlp
        )

        # Update knowledge graph
        update_knowledge_graph(news_text, ml_prediction == "REAL", knowledge_graph, nlp)

        # Get Gemini analysis with retries
        gemini_result = _request_gemini_analysis(news_text)

    # Display metrics
    st.subheader("ML Model and Knowledge Graph Analysis")
    st.metric("Prediction", ml_prediction)
    st.metric("Confidence", f"{ml_confidence:.2f}%")

    # Single expander for all analysis details
    with st.expander("Click here to get Detailed Analysis"):
        _render_detailed_analysis(gemini_result, news_text, nlp, tokenizer, model)


if __name__ == "__main__":
    main()