# AI-Powered News Analyzer — Streamlit app (news classification + question answering)
import streamlit as st
import pandas as pd
from transformers import pipeline
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer  # NOTE(review): unused below — kept for compatibility, confirm before removing
import string
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import nltk

# Fetch the NLTK data the preprocessing step relies on. nltk.download() is a
# no-op when the resource is already present; quiet=True suppresses the
# progress output that would otherwise be printed on every Streamlit rerun.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)
# Page-level configuration — must be the first Streamlit call in the script.
st.set_page_config(page_title="News Analyzer", layout="wide")

# Sleek dark-blue theme with black fonts, injected as raw CSS once per run.
_CUSTOM_CSS = """
    <style>
        /* Global Styling */
        body {
            background: #0b132b;
            font-family: 'Arial', sans-serif;
            margin: 0;
            padding: 0;
        }
        /* Header Styling */
        .custom-header {
            background: linear-gradient(to right, #1f4068, #1b1b2f);
            padding: 1.5rem;
            margin-bottom: 1.5rem;
            border-radius: 12px;
            text-align: center;
            font-size: 30px;
            font-weight: bold;
            box-shadow: 0px 4px 15px rgba(0, 217, 255, 0.3);
        }
        /* Buttons */
        .stButton>button {
            background: linear-gradient(45deg, #0072ff, #00c6ff);
            border-radius: 8px;
            padding: 14px 28px;
            font-size: 18px;
            transition: 0.3s ease;
            border: none;
        }
        .stButton>button:hover {
            transform: scale(1.05);
            box-shadow: 0px 4px 10px rgba(0, 255, 255, 0.5);
        }
        /* Text Input */
        .stTextInput>div>div>input {
            background-color: rgba(255, 255, 255, 0.1);
            border-radius: 8px;
            padding: 12px;
            font-size: 18px;
        }
        /* Dataframe Container */
        .dataframe-container {
            background: rgba(255, 255, 255, 0.1);
            padding: 15px;
            border-radius: 12px;
        }
        /* Answer Display Box - Larger */
        .answer-box {
            background: rgba(0, 217, 255, 0.15);
            padding: 35px;
            border-radius: 15px;
            border: 2px solid rgba(0, 217, 255, 0.6);
            font-size: 22px;
            text-align: center;
            margin-bottom: 20px;
            min-height: 150px;
            box-shadow: 0px 2px 12px rgba(0, 217, 255, 0.3);
            display: flex;
            align-items: center;
            justify-content: center;
            transition: all 0.3s ease;
        }
        /* CSV Display Box */
        .csv-box {
            background: rgba(255, 255, 255, 0.1);
            padding: 15px;
            border-radius: 12px;
            margin-top: 20px;
            box-shadow: 0px 2px 12px rgba(0, 217, 255, 0.3);
        }
    </style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Page header banner (styled by .custom-header above).
st.markdown("<div class='custom-header'> ๐งฉ AI-Powered News Analyzer</div>", unsafe_allow_html=True)
# Load the Hugging Face models.
# st.cache_resource ensures the heavy models are loaded exactly once per
# server process instead of on every Streamlit rerun (each widget interaction
# re-executes the whole script top to bottom).
@st.cache_resource(show_spinner="Loading NLP models...")
def _load_models():
    """Load and cache the three NLP models used by the app.

    Returns:
        tuple: (classification pipeline, question-answering pipeline,
        cross-encoder used to score question/article relevance).
    """
    clf = pipeline("text-classification", model="Sandini/news-classifier")  # Classification pipeline
    qa = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad")  # QA pipeline
    ce = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')  # Pre-trained Cross-Encoder model
    return clf, qa, ce

classifier, qa_pipeline, cross_encoder = _load_models()
# Preprocessing helpers for classification. The stopword set and lemmatizer
# are built once at import time; the original rebuilt both on every call,
# which is wasteful when applied row-by-row over a whole dataframe.
_STOP_WORDS = set(stopwords.words('english'))
_LEMMATIZER = WordNetLemmatizer()
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)


def preprocess_text(text):
    """Normalize raw article text for the classifier.

    Steps: lowercase -> strip punctuation -> tokenize -> drop English
    stopwords -> lemmatize -> re-join with single spaces.

    Args:
        text: Raw article text; non-string values (e.g. NaN) are treated
            as empty.

    Returns:
        str: The normalized, space-joined token string.
    """
    if not isinstance(text, str):
        text = ""
    text = text.lower()
    text = text.translate(_PUNCT_TABLE)
    tokens = word_tokenize(text)
    # Filter stopwords and lemmatize in a single pass.
    tokens = [_LEMMATIZER.lemmatize(tok) for tok in tokens if tok not in _STOP_WORDS]
    return " ".join(tokens)
# Category name -> numeric id, as used when the classifier was trained.
label_mapping = {
    "Business": 0,
    "Opinion": 1,
    "Sports": 2,
    "Political_gossip": 3,
    "World_news": 4
}
# Inverted once up front so a predicted numeric id maps straight back to
# its human-readable category name.
reverse_label_mapping = dict(zip(label_mapping.values(), label_mapping.keys()))
# Map a single text to its predicted news category.
def predict_category(text):
    """Classify one preprocessed article and return its category name.

    The classifier emits labels of the form 'LABEL_<id>'; the trailing
    numeric id is translated back through reverse_label_mapping.
    """
    raw_label = classifier(text)[0]['label']
    label_id = int(raw_label.rsplit('_', 1)[-1])
    return reverse_label_mapping[label_id]
# Responsive two-column layout: data upload on the left, Q&A on the right.
col1, col2 = st.columns([1.1, 1])

# Left Section - File Upload & CSV/Excel Display
with col1:
    st.subheader("๐ Upload News Data")
    uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])

    if uploaded_file is not None:
        # Pick the reader from the extension. .lower() fixes a crash on
        # uppercase names like REPORT.CSV (neither branch matched, so df
        # was never assigned and later code raised NameError).
        file_extension = uploaded_file.name.split('.')[-1].lower()
        if file_extension == 'csv':
            df = pd.read_csv(uploaded_file)
        elif file_extension == 'xlsx':
            df = pd.read_excel(uploaded_file)
        else:
            st.error("Unsupported file type. Please upload a .csv or .xlsx file.")
            st.stop()  # halt the rerun cleanly instead of crashing below

        # Preprocess the content column and predict a category per row.
        if 'content' in df.columns:
            df['content'] = df['content'].fillna("").astype(str)
            df['preprocessed_content'] = df['content'].apply(preprocess_text)
            df['class'] = df['preprocessed_content'].apply(predict_category)

        # The preprocessed column is internal — hide it from both the
        # preview and the download (one frame serves both; they were
        # identical copies before).
        df_for_display = df.drop(columns=['preprocessed_content'], errors='ignore')

        # Download button for the classified data.
        st.download_button(
            label="โฌ๏ธ Download Processed Data",
            data=df_for_display.to_csv(index=False).encode('utf-8'),
            file_name="output.csv",
            mime="text/csv"
        )

        # CSV Preview Box
        st.markdown("<div class='csv-box'><h4>๐ CSV/Excel Preview</h4></div>", unsafe_allow_html=True)
        st.dataframe(df_for_display, use_container_width=True)
# Right Section - Q&A Interface
with col2:
    st.subheader("๐ค AI Assistant")

    # Placeholder so the (empty) answer box is visible before any question.
    answer_placeholder = st.empty()
    answer_placeholder.markdown("<div class='answer-box'></div>", unsafe_allow_html=True)

    # Question Input
    st.markdown("### ๐ Ask Your Question:")
    user_question = st.text_input("Enter your question here", label_visibility="hidden")  # Hides the label

    # Button & Answer Display
    if st.button("๐ฎ Get Answer"):
        # 'df' in globals() guards the case where a file was uploaded but
        # never successfully parsed — previously this raised NameError.
        if user_question.strip() and uploaded_file is not None and 'df' in globals():
            if 'content' in df.columns:
                # Use the non-empty articles as candidate contexts.
                context = [c for c in df['content'].dropna().astype(str).tolist() if c.strip()]
                if context:
                    # Score every (question, article) pair with the
                    # cross-encoder and keep the 5 most relevant articles.
                    pairs = [(user_question, c) for c in context]
                    scores = cross_encoder.predict(pairs)
                    top_indices = scores.argsort()[-5:][::-1]
                    top_context = "\n".join([context[i] for i in top_indices])
                    # Extract the answer span from the concatenated top context.
                    result = qa_pipeline(question=user_question, context=top_context)
                    answer = result['answer']
                else:
                    # New guard: an all-empty content column used to crash
                    # inside qa_pipeline with an empty context string.
                    answer = "The 'content' column has no text to search!"
            else:
                answer = "โ ๏ธ File does not contain a 'content' column!"
        else:
            answer = "โ ๏ธ Please upload a valid file first!"

        answer_placeholder.markdown(f"<div class='answer-box'>{answer}</div>", unsafe_allow_html=True)