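"""News Classifier: a Streamlit app that classifies news articles with a
Hugging Face text-classification model (Oneli/News_Classification), supports
bulk CSV classification with word-cloud and bar-chart summaries, and answers
questions about the classified text via an extractive QA pipeline."""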
import streamlit as st
import pandas as pd
import re
import string
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from transformers import pipeline
from PIL import Image
import matplotlib.pyplot as plt
from wordcloud import WordCloud
# Download required NLTK data (quiet=True avoids console noise on every rerun)
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)
# Load Models
# (For faster reruns, these loads could be wrapped in @st.cache_resource.)
news_classifier = pipeline("text-classification", model="Oneli/News_Classification")
qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
# Label Mapping
label_mapping = {
    "LABEL_0": "Business",
    "LABEL_1": "Opinion",
    "LABEL_2": "Political Gossip",
    "LABEL_3": "Sports",
    "LABEL_4": "World News"
}
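# NOTE: this LABEL_i -> category mapping is assumed to match the
# Oneli/News_Classification model; it can be verified against
# news_classifier.model.config.id2label.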
# Store classified article for QA
context_storage = {"context": "", "bulk_context": "", "num_articles": 0}
# Text Cleaning Functions
stop_words = set(stopwords.words("english"))  # Build the stopword set once
lemmatizer = WordNetLemmatizer()

def clean_text(text):
    text = text.lower()
    text = re.sub(f"[{re.escape(string.punctuation)}]", "", text)  # Remove punctuation (escaped so regex metacharacters stay literal)
    text = re.sub(r"[^a-zA-Z0-9\s]", "", text)  # Remove remaining special characters
    words = text.split()  # Whitespace tokenization (no Punkt needed)
    words = [word for word in words if word not in stop_words]  # Remove stopwords
    words = [lemmatizer.lemmatize(word) for word in words]  # Lemmatize tokens
    return " ".join(words)
# Classification functions
def classify_text(text):
    cleaned_text = clean_text(text)
    result = news_classifier(cleaned_text)[0]
    category = label_mapping.get(result['label'], "Unknown")
    confidence = round(result['score'] * 100, 2)
    # Store context for QA
    context_storage["context"] = cleaned_text
    return category, f"Confidence: {confidence}%"
def classify_csv(file):
    try:
        df = pd.read_csv(file, encoding="utf-8")
        text_column = df.columns[0]  # Assume the first column is the text column
        df[text_column] = df[text_column].astype(str).apply(clean_text)  # Clean text column
        # Run the classifier once per row, then unpack label and score
        results = df[text_column].apply(lambda x: news_classifier(x)[0])
        df["Decoded Prediction"] = results.apply(lambda r: label_mapping.get(r["label"], "Unknown"))
        df["Confidence"] = results.apply(lambda r: round(r["score"] * 100, 2))
        # Store all text as a single context for QA
        context_storage["bulk_context"] = " ".join(df[text_column].dropna().astype(str).tolist())
        context_storage["num_articles"] = len(df)
        output_file = "output.csv"
        df.to_csv(output_file, index=False)
        return df, output_file
    except Exception as e:
        return None, f"Error: {str(e)}"
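# Expected CSV shape (assumption: the first column holds the article text;
# any extra columns are carried through to the output unchanged), e.g.:
#   content
#   "Stocks rallied after the quarterly earnings report ..."
#   "The midfielder signed a two-year contract extension ..."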
def chatbot_response(history, user_input, text_input=None, file_input=None):
    user_input = user_input.lower()
    context = ""
    if text_input:
        context += text_input
    if file_input:
        classify_csv(file_input)  # Populates context_storage["bulk_context"]
        context += context_storage["bulk_context"]
    if context:
        with st.spinner("Finding answer..."):
            result = qa_pipeline(question=user_input, context=context)
            answer = result["answer"]
    else:
        answer = "No context available. Classify an article or upload a CSV first."
    history.append([user_input, answer])
    return history, answer
# Function to generate a word cloud from the text column of a processed DataFrame
def generate_word_cloud_from_output(df):
    # Use the 'content' column if present, otherwise fall back to the first column
    text_column = "content" if "content" in df.columns else df.columns[0]
    content_text = " ".join(df[text_column].dropna().astype(str).tolist())
    wordcloud = WordCloud(width=800, height=400, background_color="white").generate(content_text)
    return wordcloud
# Function to generate a bar graph of decoded prediction frequencies
def generate_bar_graph(df):
    prediction_counts = df["Decoded Prediction"].value_counts()
    fig, ax = plt.subplots(figsize=(10, 6))
    prediction_counts.plot(kind='bar', ax=ax, color='skyblue')
    ax.set_title('Frequency of Decoded Predictions', fontsize=16)
    ax.set_xlabel('Category', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    st.pyplot(fig)
# Streamlit App Layout
st.set_page_config(page_title="News Classifier", page_icon="📰")
# Load image
cover_image = Image.open("cover.png")  # Ensure this image exists
# Display image
st.image(cover_image, use_container_width=True)
# Custom styled caption
st.markdown(
    "<h2 style='text-align: center; font-size: 32px;'>News Classifier 📰</h2>",
    unsafe_allow_html=True
)
# Section for Single Article Classification
st.subheader("📰 Single Article Classification")
text_input = st.text_area("Enter News Text", placeholder="Type or paste news content here...")
if st.button("🔍 Classify"):
    if text_input:
        category, confidence = classify_text(text_input)
        st.write(f"Predicted Category: {category}")
        st.write(f"Confidence Level: {confidence}")
        # Generate word cloud for the text input (wrapped in a one-row DataFrame)
        wordcloud = generate_word_cloud_from_output(pd.DataFrame({"content": [text_input]}))
        st.image(wordcloud.to_array(), caption="Word Cloud for Text Input", use_container_width=True)
    else:
        st.warning("Please enter some text to classify.")
# Section for Bulk CSV Classification
st.subheader("📂 Bulk Classification (CSV)")
file_input = st.file_uploader("Upload CSV File", type="csv")
if file_input:
    df, output_file = classify_csv(file_input)
    if df is not None:
        st.dataframe(df)
        with open(output_file, 'rb') as f:  # Close the file handle after reading
            st.download_button(
                label="Download Processed CSV",
                data=f.read(),
                file_name=output_file,
                mime="text/csv"
            )
        # Generate word cloud for the text column of the processed CSV data
        wordcloud = generate_word_cloud_from_output(df)
        st.image(wordcloud.to_array(), caption="Word Cloud for CSV Content", use_container_width=True)
        # Generate bar graph for decoded prediction frequencies
        generate_bar_graph(df)
    else:
        st.error(f"Error processing file: {output_file}")
# Section for Chatbot Interaction
st.subheader("💬 AI Chat Assistant")
# Keep chat history across Streamlit reruns
if "history" not in st.session_state:
    st.session_state.history = []
user_input = st.text_input("Ask about news classification or topics", placeholder="Type a message...")
source_toggle = st.radio("Select Context Source", ["Single Article", "Bulk Classification"])
if st.button("➤ Send"):
    if not user_input:
        st.warning("Please type a question for QA.")
    else:
        st.session_state.history, bot_response = chatbot_response(
            st.session_state.history,
            user_input,
            text_input=text_input if source_toggle == "Single Article" else None,
            file_input=file_input if source_toggle == "Bulk Classification" else None
        )
        st.write("Chatbot Response:")
        for q, a in st.session_state.history:
            st.write(f"Q: {q}")
            st.write(f"A: {a}")