TAgroup5's picture
Update app.py
ebaf596 verified
raw
history blame
3.54 kB
import io
import re

import pandas as pd
import streamlit as st
from transformers import (
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)
# Load fine-tuned models and tokenizers for both functions.
# Streamlit re-executes this whole script on every widget interaction, so the
# loaders below are cached with st.cache_resource: each checkpoint is loaded
# once per server process instead of on every rerun.
model_name_classification = "TAgroup5/news-classification-model"  # Replace with the correct model name
model_name_qa = "distilbert-base-cased-distilled-squad"


@st.cache_resource
def _load_classification_model(name):
    """Return ``(model, tokenizer)`` for the sequence-classification checkpoint *name*."""
    return (
        AutoModelForSequenceClassification.from_pretrained(name),
        AutoTokenizer.from_pretrained(name),
    )


@st.cache_resource
def _load_qa_model(name):
    """Return ``(model, tokenizer)`` for the extractive question-answering checkpoint *name*."""
    return (
        AutoModelForQuestionAnswering.from_pretrained(name),
        AutoTokenizer.from_pretrained(name),
    )


model, tokenizer = _load_classification_model(model_name_classification)
model_qa, tokenizer_qa = _load_qa_model(model_name_qa)

# Initialize pipelines. Each pipeline must receive the model/tokenizer trained
# for its own task: a sequence-classification model has no start/end span
# head, so it cannot back the question-answering pipeline.
text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)
# Streamlit App
st.title("News Classification and Q&A")

## ====================== Component 1: News Classification ====================== ##
st.header("Classify News Articles")
st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")


def preprocess_text(text):
    """Normalize raw article text before classification.

    Lower-cases the text, removes every character that is not an ASCII letter
    or whitespace (punctuation, digits, accented letters), then collapses
    whitespace runs to a single space and trims both ends. Non-string values
    (e.g. NaN, which pandas produces for empty CSV cells) become "".
    No tokenization is done here; the model's tokenizer handles that.
    """
    if not isinstance(text, str):
        return ""
    text = text.lower()  # Convert to lowercase
    text = re.sub(r'[^a-z\s]', '', text)  # Remove special characters & numbers
    # Collapse whitespace *after* the removal above, which can leave gaps.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


def _classify(text):
    """Return the top predicted label for *text*, or "Unknown" when it is blank."""
    if not text.strip():
        return "Unknown"
    # truncation=True clips articles longer than the model's maximum input
    # length; without it such inputs fail instead of being classified.
    return text_classification_pipeline(text, truncation=True)[0]['label']


uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file, encoding="utf-8")  # Handle encoding issues
    except UnicodeDecodeError:
        # The failed UTF-8 attempt may have advanced the stream position;
        # rewind so the fallback read starts from the first byte.
        uploaded_file.seek(0)
        df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")

    if 'content' not in df.columns:
        st.error("Error: The uploaded CSV must contain a 'content' column.")
    else:
        st.write("Preview of uploaded data:")
        st.dataframe(df.head())

        # Apply preprocessing, then classify each record with the fine-tuned model.
        df['processed_content'] = df['content'].apply(preprocess_text)
        df['class'] = df['processed_content'].apply(_classify)

        # Show results
        st.write("Classification Results:")
        st.dataframe(df[['content', 'class']])

        # Provide CSV download (utf-8-sig adds a BOM so spreadsheet apps detect UTF-8).
        output = io.BytesIO()
        df.to_csv(output, index=False, encoding="utf-8-sig")
        st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")
## ====================== Component 2: Q&A ====================== ##
st.header("Ask a Question About the News")
st.markdown("Enter a question and provide a news article to get an answer.")


@st.cache_resource
def _load_answering_pipeline(name):
    """Build a question-answering pipeline for the checkpoint *name*.

    Cached with st.cache_resource so the model is loaded once per server
    process rather than rebuilt on every Streamlit rerun.
    """
    return pipeline("question-answering", model=name, tokenizer=name)


question = st.text_input("Ask a question:")
context = st.text_area("Provide the news article or content for the Q&A:", height=150)

# Only query the model when both the question and the context contain
# something other than whitespace.
if question.strip() and context.strip():
    model_name_qa = "distilbert-base-uncased-distilled-squad"  # Example of a common Q&A model
    qa_pipeline = _load_answering_pipeline(model_name_qa)
    result = qa_pipeline(question=question, context=context)
    # Treat a missing or whitespace-only answer span as "no answer".
    answer = (result.get('answer') or '').strip()
    if answer:
        st.write("Answer:", answer)
    else:
        st.write("No answer found in the provided content.")