# Nexus_NLP_model / app.py
import random
import time

import pandas as pd
import plotly.graph_objects as go
import networkx as nx
import spacy
import streamlit as st
from pydantic import BaseModel
# Explicit imports for classes used below; `from final import *` may also provide them
from transformers import AutoModelForSequenceClassification, DebertaV2Tokenizer

# Page configuration (st.set_page_config must be the first Streamlit call)
st.set_page_config(
    page_title="Nexus NLP News Classifier"
)

# Project helpers (prediction, knowledge graph, and Gemini utilities) come from final.py
from final import *

@st.cache_resource
def initialize_models():
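    """Load and cache the spaCy pipeline, DeBERTa tokenizer/classifier, and knowledge graph.

    Decorated with st.cache_resource, so the models are loaded once and reused
    across Streamlit reruns.
    """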
    try:
        nlp = spacy.load("en_core_web_sm")
    except OSError:
        # Model not installed yet: download it, then retry loading
        spacy.cli.download("en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
model_path = "./results/checkpoint-753"
tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-small')
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
knowledge_graph = load_knowledge_graph()
return nlp, tokenizer, model, knowledge_graph
class NewsInput(BaseModel):
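    """Request schema for a single news article (text only)."""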
text: str
def generate_knowledge_graph_viz(text, nlp, tokenizer, model):
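    """Build a knowledge graph for the article text and return a Plotly figure.

    The DeBERTa prediction sets the dominant edge color (red for FAKE, green for
    REAL). About 60% of the graph's edges are sampled for display; of the full
    edge set, ~30% are drawn in the dominant color, ~15% in the opposite color,
    and ~15% in orange. Sampled edges left unassigned still shape the spring
    layout but are not drawn.
    """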
kg_builder = KnowledgeGraphBuilder()
# Get prediction
prediction, _ = predict_with_model(text, tokenizer, model)
is_fake = prediction == "FAKE"
# Update knowledge graph
kg_builder.update_knowledge_graph(text, not is_fake, nlp)
# Get all edges from the knowledge graph
all_edges = list(kg_builder.knowledge_graph.edges())
total_edges = len(all_edges)
# Select only 60% of edges to display (0.3 + 0.15 + 0.15)
display_edge_count = int(total_edges * 0.6)
display_edges = random.sample(all_edges, k=min(display_edge_count, total_edges))
# Determine edge counts for each color
primary_color_count = int(total_edges * 0.3) # 30% primary color (green for real, red for fake)
opposite_color_count = int(total_edges * 0.15) # 15% opposite color
orange_color_count = int(total_edges * 0.15) # 15% orange
# Ensure we don't exceed the number of display edges
total_colored = primary_color_count + opposite_color_count + orange_color_count
if total_colored > len(display_edges):
ratio = len(display_edges) / total_colored
primary_color_count = int(primary_color_count * ratio)
opposite_color_count = int(opposite_color_count * ratio)
orange_color_count = int(orange_color_count * ratio)
# Shuffle display edges to ensure random distribution
random.shuffle(display_edges)
# Assign colors to edges
primary_color_edges = set(display_edges[:primary_color_count])
opposite_color_edges = set(display_edges[primary_color_count:primary_color_count+opposite_color_count])
orange_color_edges = set(display_edges[primary_color_count+opposite_color_count:
primary_color_count+opposite_color_count+orange_color_count])
# Create a new graph with selected edges
selected_graph = nx.DiGraph()
selected_graph.add_nodes_from(kg_builder.knowledge_graph.nodes(data=True))
selected_graph.add_edges_from(display_edges)
pos = nx.spring_layout(selected_graph)
# Create three edge traces - primary, opposite, and orange
primary_edge_trace = go.Scatter(
x=[], y=[],
line=dict(
width=2,
color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)' # Red if fake, green if real
),
hoverinfo='none',
mode='lines'
)
opposite_edge_trace = go.Scatter(
x=[], y=[],
line=dict(
width=2,
color='rgba(0,255,0,0.7)' if is_fake else 'rgba(255,0,0,0.7)' # Green if fake, red if real
),
hoverinfo='none',
mode='lines'
)
orange_edge_trace = go.Scatter(
x=[], y=[],
line=dict(
width=2,
color='rgba(255,165,0,0.7)' # Orange
),
hoverinfo='none',
mode='lines'
)
node_trace = go.Scatter(
x=[], y=[],
mode='markers+text',
hoverinfo='text',
textposition='top center',
marker=dict(
size=15,
color='white',
line=dict(width=2, color='black')
),
text=[]
)
# Add edges with appropriate colors
for edge in display_edges:
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
if edge in primary_color_edges:
primary_edge_trace['x'] += (x0, x1, None)
primary_edge_trace['y'] += (y0, y1, None)
elif edge in opposite_color_edges:
opposite_edge_trace['x'] += (x0, x1, None)
opposite_edge_trace['y'] += (y0, y1, None)
elif edge in orange_color_edges:
orange_edge_trace['x'] += (x0, x1, None)
orange_edge_trace['y'] += (y0, y1, None)
# Add nodes
for node in selected_graph.nodes():
x, y = pos[node]
node_trace['x'] += (x,)
node_trace['y'] += (y,)
node_trace['text'] += (node,)
fig = go.Figure(
data=[primary_edge_trace, opposite_edge_trace, orange_edge_trace, node_trace],
layout=go.Layout(
showlegend=False,
hovermode='closest',
margin=dict(b=0,l=0,r=0,t=0),
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
plot_bgcolor='rgba(0,0,0,0)',
paper_bgcolor='rgba(0,0,0,0)'
)
)
return fig
# Streamlit UI
def main():
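    """Streamlit entry point: collect news text, run the models, and render the analysis."""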
st.title("Nexus NLP News Classifier")
st.write("Enter news text below to analyze its authenticity")
# Initialize models
nlp, tokenizer, model, knowledge_graph = initialize_models()
# Text input area
news_text = st.text_area("News Text", height=200)
if st.button("Analyze"):
if news_text:
with st.spinner("Analyzing..."):
# Get predictions from all models
ml_prediction, ml_confidence = predict_with_model(news_text, tokenizer, model)
kg_prediction, kg_confidence = predict_with_knowledge_graph(news_text, knowledge_graph, nlp)
# Update knowledge graph
update_knowledge_graph(news_text, ml_prediction == "REAL", knowledge_graph, nlp)
                # Get Gemini analysis (with retries)
max_retries = 10
retry_count = 0
gemini_result = None
while retry_count < max_retries:
try:
gemini_model = setup_gemini()
gemini_result = analyze_content_gemini(gemini_model, news_text)
# Check if we got valid results
if gemini_result and gemini_result.get('gemini_analysis'):
break
except Exception as e:
st.error(f"Gemini API error: {str(e)}")
print(f"Gemini error: {str(e)}")
                    retry_count += 1
                    time.sleep(1)  # brief pause between retries
# Use default values if all retries failed
if not gemini_result:
gemini_result = {
"gemini_analysis": {
"predicted_classification": "UNCERTAIN",
"confidence_score": "50",
"reasoning": ["Analysis temporarily unavailable"]
}
}
# Display metrics in columns
col1 = st.columns(1)[0]
with col1:
st.subheader("ML Model and Knowedge Graph Analysis")
st.metric("Prediction", ml_prediction)
st.metric("Confidence", f"{ml_confidence:.2f}%")
# with col2:
# st.subheader("Knowledge Graph Analysis")
# st.metric("Prediction", kg_prediction)
# st.metric("Confidence", f"{kg_confidence:.2f}%")
# with col3:
# st.subheader("Gemini Analysis")
# gemini_pred = gemini_result["gemini_analysis"]["predicted_classification"]
# gemini_conf = gemini_result["gemini_analysis"]["confidence_score"]
# st.metric("Prediction", gemini_pred)
# st.metric("Confidence", f"{gemini_conf}%")
# Single expander for all analysis details
with st.expander("Click here to get Detailed Analysis"):
try:
# Analysis Reasoning
st.subheader("πŸ’­ Analysis Reasoning")
for point in gemini_result.get('gemini_analysis', {}).get('reasoning', ['N/A']):
st.write(f"β€’ {point}")
# Named Entities from spaCy
st.subheader("🏷️ Named Entities")
entities = extract_entities(news_text, nlp)
df = pd.DataFrame(entities, columns=["Entity", "Type"])
st.dataframe(df)
# Knowledge Graph Visualization
st.subheader("πŸ•ΈοΈ Knowledge Graph")
fig = generate_knowledge_graph_viz(news_text, nlp, tokenizer, model)
st.plotly_chart(fig, use_container_width=True)
# Text Classification
st.subheader("πŸ“ Text Classification")
text_class = gemini_result.get('text_classification', {})
st.write(f"Category: {text_class.get('category', 'N/A')}")
st.write(f"Writing Style: {text_class.get('writing_style', 'N/A')}")
st.write(f"Target Audience: {text_class.get('target_audience', 'N/A')}")
st.write(f"Content Type: {text_class.get('content_type', 'N/A')}")
# Sentiment Analysis
st.subheader("🎭 Sentiment Analysis")
sentiment = gemini_result.get('sentiment_analysis', {})
st.write(f"Primary Emotion: {sentiment.get('primary_emotion', 'N/A')}")
st.write(f"Emotional Intensity: {sentiment.get('emotional_intensity', 'N/A')}/10")
st.write(f"Sensationalism Level: {sentiment.get('sensationalism_level', 'N/A')}")
st.write("Bias Indicators:", ", ".join(sentiment.get('bias_indicators', ['N/A'])))
# Entity Recognition
st.subheader("πŸ” Entity Recognition")
entities = gemini_result.get('entity_recognition', {})
st.write(f"Source Credibility: {entities.get('source_credibility', 'N/A')}")
st.write("People:", ", ".join(entities.get('people', ['N/A'])))
st.write("Organizations:", ", ".join(entities.get('organizations', ['N/A'])))
                    except Exception as e:
                        st.error(f"Error processing analysis results: {e}")
else:
st.warning("Please enter some text to analyze")
if __name__ == "__main__":
main()