"""Streamlit front end for the Nexus NLP news classifier.

Runs a fine-tuned sequence classifier, a knowledge-graph lookup and a Gemini
analysis (helpers supplied by the local ``final`` module) on user-supplied
news text and renders the results.
"""
import streamlit as st
import random

# Page configuration. ``st.set_page_config`` must be the first Streamlit
# command executed, so it stays ahead of the remaining imports.
st.set_page_config(
    page_title="Nexus NLP News Classifier"
)

import pandas as pd  # noqa: E402
# ``final`` is expected to provide spacy, nx, DebertaV2Tokenizer,
# AutoModelForSequenceClassification, load_knowledge_graph,
# KnowledgeGraphBuilder, predict_with_model, predict_with_knowledge_graph,
# update_knowledge_graph, setup_gemini, analyze_content_gemini and
# extract_entities. NOTE(review): replace the star import with explicit names
# once that list is confirmed. The imports after it are intentionally kept
# below it so they take precedence over any same-named export of ``final``.
from final import *  # noqa: E402,F401,F403
from pydantic import BaseModel  # noqa: E402
import plotly.graph_objects as go  # noqa: E402

SPACY_MODEL = "en_core_web_sm"
TOKENIZER_NAME = "microsoft/deberta-v3-small"
MODEL_PATH = "./results/checkpoint-753"

GEMINI_MAX_RETRIES = 10
GEMINI_RETRY_DELAY_S = 1

# Plotly RGBA edge colours.
_RED = "rgba(255,0,0,0.7)"
_GREEN = "rgba(0,255,0,0.7)"
_ORANGE = "rgba(255,165,0,0.7)"


@st.cache_resource
def initialize_models():
    """Load the spaCy pipeline, tokenizer, classifier and knowledge graph.

    Wrapped in ``st.cache_resource`` so the objects are built once and reused
    across reruns and sessions.

    Returns:
        tuple: ``(nlp, tokenizer, model, knowledge_graph)``.
    """
    try:
        nlp = spacy.load(SPACY_MODEL)
    except OSError:
        # spaCy raises OSError when the model package is not installed.
        spacy.cli.download(SPACY_MODEL)
        nlp = spacy.load(SPACY_MODEL)

    tokenizer = DebertaV2Tokenizer.from_pretrained(TOKENIZER_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.eval()  # inference mode (e.g. disables dropout)

    knowledge_graph = load_knowledge_graph()
    return nlp, tokenizer, model, knowledge_graph


class NewsInput(BaseModel):
    """Pydantic schema for a news-text payload (not used by the UI below)."""

    text: str


def _edge_trace(edges, pos, color):
    """Return a line trace drawing each ``(src, dst)`` edge in *color*."""
    xs, ys = [], []
    for src, dst in edges:
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        # A trailing None breaks the polyline so segments are not joined.
        xs.extend((x0, x1, None))
        ys.extend((y0, y1, None))
    return go.Scatter(
        x=xs,
        y=ys,
        line=dict(width=2, color=color),
        hoverinfo="none",
        mode="lines",
    )


def _node_trace(graph, pos):
    """Return a labelled marker trace for every node of *graph*."""
    nodes = list(graph.nodes())
    return go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        text=nodes,
        mode="markers+text",
        hoverinfo="text",
        textposition="top center",
        marker=dict(size=15, color="white", line=dict(width=2, color="black")),
    )


def generate_knowledge_graph_viz(text, nlp, tokenizer, model, prediction=None):
    """Build a Plotly figure of the knowledge graph extracted from *text*.

    A fresh ``KnowledgeGraphBuilder`` is updated with *text*, flagged as real
    unless the classifier predicts ``"FAKE"``. A random 60% of its edges is
    drawn. 15% of all edges are drawn in the opposite colour and 15% in
    orange. The remaining drawn edges (about 30% of all edges) use the
    prediction's colour: red for FAKE, green otherwise.

    NOTE(review): edge colours are assigned purely at random and do not
    reflect per-edge evidence; ``nx.spring_layout`` is unseeded, so the
    figure also changes on every call.

    Args:
        text: News text to visualise.
        nlp: spaCy pipeline passed to the graph builder.
        tokenizer: Tokenizer used when *prediction* is not supplied.
        model: Classifier used when *prediction* is not supplied.
        prediction: Optional precomputed label (e.g. ``"FAKE"``/``"REAL"``)
            to avoid running the classifier a second time.

    Returns:
        plotly.graph_objects.Figure: The rendered graph.
    """
    kg_builder = KnowledgeGraphBuilder()

    if prediction is None:
        prediction, _ = predict_with_model(text, tokenizer, model)
    is_fake = prediction == "FAKE"

    kg_builder.update_knowledge_graph(text, not is_fake, nlp)
    graph = kg_builder.knowledge_graph

    all_edges = list(graph.edges())
    total_edges = len(all_edges)
    # random.sample returns edges in random order, so slicing below already
    # yields random, disjoint colour groups.
    display_edges = random.sample(all_edges, k=int(total_edges * 0.6))

    shown = len(display_edges)
    opposite_count = min(int(total_edges * 0.15), shown)
    orange_count = min(int(total_edges * 0.15), shown - opposite_count)
    opposite_edges = display_edges[:opposite_count]
    orange_edges = display_edges[opposite_count:opposite_count + orange_count]
    # Every remaining displayed edge gets the primary colour, so no sampled
    # edge is laid out but left undrawn.
    primary_edges = display_edges[opposite_count + orange_count:]

    # Layout over all nodes plus only the displayed edges.
    selected_graph = nx.DiGraph()
    selected_graph.add_nodes_from(graph.nodes(data=True))
    selected_graph.add_edges_from(display_edges)
    pos = nx.spring_layout(selected_graph)

    primary_color, opposite_color = (_RED, _GREEN) if is_fake else (_GREEN, _RED)
    traces = [
        _edge_trace(primary_edges, pos, primary_color),
        _edge_trace(opposite_edges, pos, opposite_color),
        _edge_trace(orange_edges, pos, _ORANGE),
        _node_trace(selected_graph, pos),
    ]

    return go.Figure(
        data=traces,
        layout=go.Layout(
            showlegend=False,
            hovermode="closest",
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor="rgba(0,0,0,0)",
            paper_bgcolor="rgba(0,0,0,0)",
        ),
    )


def _gemini_fallback():
    """Return a fresh placeholder result used when Gemini is unavailable."""
    return {
        "gemini_analysis": {
            "predicted_classification": "UNCERTAIN",
            "confidence_score": "50",
            "reasoning": ["Analysis temporarily unavailable"],
        }
    }


def _run_gemini_analysis(news_text, max_retries=GEMINI_MAX_RETRIES,
                         retry_delay=GEMINI_RETRY_DELAY_S):
    """Query Gemini until a result containing ``gemini_analysis`` arrives.

    Each attempt's error is printed to the server log. An error is shown in
    the UI only if the final attempt still produced no usable result.

    Returns:
        The first result with a non-empty ``gemini_analysis``. Otherwise the
        last truthy result received, or the placeholder from
        ``_gemini_fallback()``.
    """
    # Imported locally so a ``time`` name exported by ``from final import *``
    # cannot shadow the stdlib module.
    import time

    result = None
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            gemini_model = setup_gemini()
            result = analyze_content_gemini(gemini_model, news_text)
            if result and result.get("gemini_analysis"):
                return result
        except Exception as exc:  # API/network error types are not known here
            last_error = exc
            print(f"Gemini error: {str(exc)}")
        if attempt < max_retries:  # no pointless wait after the last attempt
            time.sleep(retry_delay)

    if last_error is not None:
        st.error(f"Gemini API error: {str(last_error)}")
    return result or _gemini_fallback()


def _format_list(values):
    """Join *values* with ", " for display; "N/A" when missing or empty."""
    if not values:
        return "N/A"
    if isinstance(values, str):  # avoid splitting a plain string into chars
        return values
    return ", ".join(str(value) for value in values)


def _render_detailed_analysis(news_text, gemini_result, nlp, tokenizer, model,
                              ml_prediction):
    """Render reasoning, entities, knowledge graph and Gemini breakdowns."""
    analysis = gemini_result.get("gemini_analysis") or {}
    st.subheader("💭 Analysis Reasoning")
    for point in analysis.get("reasoning") or ["N/A"]:
        st.write(f"• {point}")

    st.subheader("🏷️ Named Entities")
    spacy_entities = extract_entities(news_text, nlp)
    st.dataframe(pd.DataFrame(spacy_entities, columns=["Entity", "Type"]))

    st.subheader("🕸️ Knowledge Graph")
    fig = generate_knowledge_graph_viz(
        news_text, nlp, tokenizer, model, prediction=ml_prediction
    )
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("📝 Text Classification")
    text_class = gemini_result.get("text_classification") or {}
    st.write(f"Category: {text_class.get('category', 'N/A')}")
    st.write(f"Writing Style: {text_class.get('writing_style', 'N/A')}")
    st.write(f"Target Audience: {text_class.get('target_audience', 'N/A')}")
    st.write(f"Content Type: {text_class.get('content_type', 'N/A')}")

    st.subheader("🎭 Sentiment Analysis")
    sentiment = gemini_result.get("sentiment_analysis") or {}
    st.write(f"Primary Emotion: {sentiment.get('primary_emotion', 'N/A')}")
    st.write(f"Emotional Intensity: {sentiment.get('emotional_intensity', 'N/A')}/10")
    st.write(f"Sensationalism Level: {sentiment.get('sensationalism_level', 'N/A')}")
    st.write("Bias Indicators:", _format_list(sentiment.get("bias_indicators")))

    st.subheader("🔍 Entity Recognition")
    gemini_entities = gemini_result.get("entity_recognition") or {}
    st.write(f"Source Credibility: {gemini_entities.get('source_credibility', 'N/A')}")
    st.write("People:", _format_list(gemini_entities.get("people")))
    st.write("Organizations:", _format_list(gemini_entities.get("organizations")))


# Streamlit UI
def main():
    """Streamlit entry point: read news text, classify it and show the analysis."""
    st.title("Nexus NLP News Classifier")
    st.write("Enter news text below to analyze its authenticity")

    nlp, tokenizer, model, knowledge_graph = initialize_models()

    news_text = st.text_area("News Text", height=200)
    if not st.button("Analyze"):
        return
    if not news_text or not news_text.strip():
        st.warning("Please enter some text to analyze")
        return

    with st.spinner("Analyzing..."):
        ml_prediction, ml_confidence = predict_with_model(news_text, tokenizer, model)
        # NOTE(review): the knowledge-graph verdict is no longer displayed; the
        # call is kept in case it has side effects — drop it if it is pure.
        predict_with_knowledge_graph(news_text, knowledge_graph, nlp)

        # Feed the classifier's verdict back into the cached knowledge graph.
        update_knowledge_graph(news_text, ml_prediction == "REAL", knowledge_graph, nlp)

        gemini_result = _run_gemini_analysis(news_text)

        st.subheader("ML Model and Knowledge Graph Analysis")
        st.metric("Prediction", ml_prediction)
        st.metric("Confidence", f"{ml_confidence:.2f}%")

        with st.expander("Click here to get Detailed Analysis"):
            try:
                _render_detailed_analysis(
                    news_text, gemini_result, nlp, tokenizer, model, ml_prediction
                )
            except Exception as exc:  # keep the rest of the page usable
                print(f"Analysis rendering error: {exc!r}")
                st.error(f"Error processing analysis results: {exc}")


if __name__ == "__main__":
    main()