# TAgroup5's picture
# Update app.py
# 2ec8bb5 verified
# raw
# history blame
# 4.84 kB
import streamlit as st
import pandas as pd
import re
import io
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers import AutoModelForQuestionAnswering
# Load fine-tuned models and tokenizers for both functions.
#
# Each task keeps its own model/tokenizer pair. Previously `model` and
# `tokenizer` were rebound to the question-answering checkpoint before the
# classification pipeline was created, so the "text-classification" pipeline
# wrapped a question-answering model instead of the news classifier.
#
# NOTE(review): these loads run at module level; if the script is re-executed
# on every interaction, consider wrapping them in `st.cache_resource`.

# News classification checkpoint (sequence-classification head).
model_name_classification = "TAgroup5/news-classification-model"
model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
tokenizer = AutoTokenizer.from_pretrained(model_name_classification)

# Question-answering checkpoint.
model_name_qa = "distilbert-base-cased-distilled-squad"
model_qa = AutoModelForQuestionAnswering.from_pretrained(model_name_qa)
tokenizer_qa = AutoTokenizer.from_pretrained(model_name_qa)

# Alias kept so the module still exposes `model_name` with its previous value.
model_name = model_name_qa

# Initialize pipelines.
# Tokenizers are passed explicitly: when `model` is an already-instantiated
# object (not a checkpoint name), `pipeline()` has no name to load one from.
text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)
# Streamlit App
# Page-level settings: browser-tab title and icon, wide layout.
# The icon and title emoji were mojibake (UTF-8 bytes decoded with a
# single-byte codepage) and are restored to the intended newspaper emoji.
st.set_page_config(page_title="News Classification & Q&A", page_icon="📰", layout="wide")

# Inject custom CSS for the title, sub-headers, inputs and buttons.
# The string content stays unindented so Markdown does not treat it as a code block.
st.markdown(
    """
<style>
body {background-color: #f4f4f4;}
.title {text-align: center; font-size: 36px; font-weight: bold; color: #ff4b4b;}
.subheader {font-size: 24px; color: #333; margin-bottom: 20px; text-align: right;}
.stTextInput>div>div>input {border-radius: 10px;}
.stTextArea>div>div>textarea {border-radius: 10px;}
.stButton>button {border-radius: 10px; background-color: #ff4b4b; color: white; font-weight: bold;}
</style>
""",
    unsafe_allow_html=True,
)

# Main page heading, styled by the `.title` rule above.
st.markdown('<h1 class="title">📰 News Classification & Q&A App</h1>', unsafe_allow_html=True)

# Two columns with a 2:1 width ratio.
col1, col2 = st.columns([2, 1])
def preprocess_text(text):
    """Normalize one article body before classification.

    Lower-cases the text, removes every character that is not an ASCII
    letter or whitespace, then collapses whitespace runs to single spaces.

    Args:
        text: Raw cell value from the 'content' column.

    Returns:
        The cleaned string, or "" when the cell is not a string (for example
        NaN produced by an empty CSV cell), so the caller can label the row
        "Unknown" instead of crashing on `.lower()`.
    """
    if not isinstance(text, str):
        return ""
    text = text.lower()  # Convert to lowercase
    text = re.sub(r'[^a-z\s]', '', text)  # Remove special characters & numbers
    # Collapse whitespace last so gaps left by removed characters are squeezed too.
    return re.sub(r'\s+', ' ', text).strip()


# NOTE(review): the indentation of this file was lost; the `with col2:` block
# is assumed to cover the classification component only - confirm.
with col2:
    # ====================== Component 1: News Classification ====================== #
    st.markdown('<h2 class="subheader">📌 Classify News Articles</h2>', unsafe_allow_html=True)
    st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")

    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
    if uploaded_file is not None:
        try:
            df = pd.read_csv(uploaded_file, encoding="utf-8")  # Handle encoding issues
        except UnicodeDecodeError:
            # The failed attempt already consumed part of the stream; rewind
            # so the fallback decoder reads the file from the beginning.
            uploaded_file.seek(0)
            df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")

        if 'content' not in df.columns:
            st.error("❌ Error: The uploaded CSV must contain a 'content' column.")
        else:
            st.success("✅ File successfully uploaded!")
            st.write("Preview of uploaded data:")
            st.dataframe(df.head())

            # Apply preprocessing and classification.
            df['processed_content'] = df['content'].apply(preprocess_text)
            # Rows with nothing left after cleaning are labelled "Unknown".
            # truncation=True keeps articles longer than the model's maximum
            # input length from raising inside the pipeline.
            df['class'] = df['processed_content'].apply(
                lambda x: text_classification_pipeline(x, truncation=True)[0]['label'] if x else "Unknown"
            )

            # Show results
            st.markdown("### 🔹 Classification Results:")
            st.dataframe(df[['content', 'class']])

            # Provide CSV download (UTF-8 with BOM, written to an in-memory buffer).
            output = io.BytesIO()
            df.to_csv(output, index=False, encoding="utf-8-sig")
            st.download_button(
                label="⬇️ Download classified news",
                data=output.getvalue(),
                file_name="classified_news.csv",
                mime="text/csv",
            )
# ====================== Component 2: Q&A ====================== #
st.markdown('<h2 class="subheader">❓ Ask a Question About the News</h2>', unsafe_allow_html=True)
st.markdown("Enter a question and provide a news article to get an answer.")

question = st.text_input("🔍 Ask a question:")
context = st.text_area("📝 Provide the news article or content:", height=150)

# Only query the model when both fields contain non-whitespace text.
if question.strip() and context.strip():
    # Reuse the module-level `qa_pipeline` rather than constructing a second
    # pipeline on every rerun. The previous code rebuilt it here from
    # "distilbert-base-uncased-distilled-squad", a different checkpoint from
    # the one named by the module-level `model_name_qa`, and overwrote both
    # module-level names in the process.
    result = qa_pipeline(question=question, context=context)

    # Display Answer
    answer = result.get('answer')
    if answer:
        st.markdown(f"### ✅ Answer: {answer}")
    else:
        st.markdown("### ❌ No answer found in the provided content.")