Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import re | |
import io | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline | |
from transformers import AutoModelForQuestionAnswering | |
# Load fine-tuned models and tokenizers for both functions.
# The classifier and the Q&A model are kept under separate names so that
# loading one can never overwrite the other.
model_name_classification = "TAgroup5/news-classification-model"
model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
tokenizer = AutoTokenizer.from_pretrained(model_name_classification)

model_name_qa = "distilbert-base-cased-distilled-squad"
model_qa = AutoModelForQuestionAnswering.from_pretrained(model_name_qa)
tokenizer_qa = AutoTokenizer.from_pretrained(model_name_qa)

# `model_name` stays available and keeps pointing at the Q&A checkpoint.
model_name = model_name_qa

# Initialize pipelines
# Build the classifier from the sequence-classification checkpoint (not the
# Q&A one) and pair every model object with its own tokenizer explicitly.
text_classification_pipeline = pipeline(
    "text-classification", model=model, tokenizer=tokenizer
)
qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)
# Streamlit App
# Page-level configuration: browser tab title/icon and a wide layout.
st.set_page_config(
    page_title="News Classification & Q&A",
    page_icon="π°",
    layout="wide",
)

# Custom CSS injected once for the whole page: background colour, title and
# subheader classes, and rounded corners on inputs, text areas and buttons.
_PAGE_STYLE = """
<style>
body {background-color: #f4f4f4;}
.title {text-align: center; font-size: 36px; font-weight: bold; color: #ff4b4b;}
.subheader {font-size: 24px; color: #333; margin-bottom: 20px; text-align: right;}
.stTextInput>div>div>input {border-radius: 10px;}
.stTextArea>div>div>textarea {border-radius: 10px;}
.stButton>button {border-radius: 10px; background-color: #ff4b4b; color: white; font-weight: bold;}
</style>
"""
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)

# Main heading, rendered with the `.title` class defined in the CSS above.
_TITLE_HTML = '<h1 class="title">π° News Classification & Q&A App</h1>'
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
# Preprocessing function to clean the text
def preprocess_text(text):
    """Normalise one article body before it is classified.

    Lower-cases the text, collapses every whitespace run to a single space,
    then removes each character that is not a lowercase ASCII letter or
    whitespace (digits and punctuation included).

    Args:
        text: Value of one cell of the 'content' column.

    Returns:
        The cleaned string. Non-string cells (for example the NaN pandas
        produces for a blank CSV field) give "", which the caller labels
        "Unknown" instead of sending to the model.
    """
    if not isinstance(text, str):
        return ""
    text = text.lower()  # Convert to lowercase
    text = re.sub(r'\s+', ' ', text)  # Remove extra spaces
    text = re.sub(r'[^a-z\s]', '', text)  # Remove special characters & numbers
    return text


def _classify_text(text):
    """Return the predicted label for *text*, or "Unknown" when it is blank."""
    if not text.strip():
        return "Unknown"
    # truncation=True clips articles that exceed the model's maximum input
    # length instead of failing on them.
    return text_classification_pipeline(text, truncation=True)[0]['label']


col1, col2 = st.columns([2, 1])
with col2:
    # ====================== Component 1: News Classification ====================== #
    st.markdown('<h2 class="subheader">π Classify News Articles</h2>', unsafe_allow_html=True)
    st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
    if uploaded_file is not None:
        try:
            df = pd.read_csv(uploaded_file, encoding="utf-8")  # Handle encoding issues
        except UnicodeDecodeError:
            # The failed UTF-8 attempt has already consumed part of the upload;
            # rewind so the fallback decoder reads the file from the start.
            uploaded_file.seek(0)
            df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
        if 'content' not in df.columns:
            st.error("β Error: The uploaded CSV must contain a 'content' column.")
        else:
            st.success("β File successfully uploaded!")
            st.write("Preview of uploaded data:")
            st.dataframe(df.head())
            # Apply preprocessing and classification
            df['processed_content'] = df['content'].apply(preprocess_text)
            df['class'] = df['processed_content'].apply(_classify_text)
            # Show results
            st.markdown("### πΉ Classification Results:")
            st.dataframe(df[['content', 'class']])
            # Provide CSV download (UTF-8 with BOM, written to an in-memory buffer)
            output = io.BytesIO()
            df.to_csv(output, index=False, encoding="utf-8-sig")
            st.download_button(
                label="β¬οΈ Download classified news",
                data=output.getvalue(),
                file_name="classified_news.csv",
                mime="text/csv",
            )
# ====================== Component 2: Q&A ====================== #
# NOTE(review): the indentation of this section was lost in the copy under
# review, so it is written at top level — confirm whether it belongs inside
# the `with col2:` block.
@st.cache_resource
def _load_qa_pipeline(checkpoint):
    """Build the question-answering pipeline for *checkpoint*.

    Cached by Streamlit so the model is loaded once and reused, rather than
    being rebuilt on every rerun of the script.
    """
    return pipeline("question-answering", model=checkpoint, tokenizer=checkpoint)


st.markdown('<h2 class="subheader">β Ask a Question About the News</h2>', unsafe_allow_html=True)
st.markdown("Enter a question and provide a news article to get an answer.")
question = st.text_input("π Ask a question:")
context = st.text_area("π Provide the news article or content:", height=150)
# Both inputs must contain something other than whitespace before the model runs.
if question.strip() and context.strip():
    model_name_qa = "distilbert-base-uncased-distilled-squad"
    qa_pipeline = _load_qa_pipeline(model_name_qa)
    result = qa_pipeline(question=question, context=context)
    # Display Answer
    if 'answer' in result and result['answer']:
        st.markdown(f"### β Answer: {result['answer']}")
    else:
        st.markdown("### β No answer found in the provided content.")