# Spaces: Sleeping
# (Page-header residue from the hosting site's file view; not part of the app.)
import io

import pandas as pd
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# Checkpoint fine-tuned for news-topic classification.
model_name = "TAgroup5/news-classification-model"

# Extractive question answering needs a checkpoint with a span-prediction
# (start/end logits) head. The classification checkpoint above is loaded as a
# sequence-classification model and has no such head, so the Q&A pipeline gets
# its own SQuAD-tuned checkpoint instead of reusing that model object.
# NOTE(review): swap in a project-specific QA checkpoint if one exists.
qa_model_name = "distilbert-base-cased-distilled-squad"


@st.cache_resource
def load_text_classification(checkpoint):
    """Load the classifier model, tokenizer and pipeline for ``checkpoint``.

    Streamlit re-executes this script on every widget interaction; caching the
    loaded objects avoids re-instantiating the model on each rerun.

    Returns:
        A ``(model, tokenizer, pipeline)`` tuple.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    classifier = pipeline(
        "text-classification",
        model=loaded_model,
        tokenizer=loaded_tokenizer,
    )
    return loaded_model, loaded_tokenizer, classifier


@st.cache_resource
def load_question_answering(checkpoint):
    """Load an extractive question-answering pipeline for ``checkpoint``.

    Passing the checkpoint name lets ``pipeline`` pick the matching
    question-answering model class and tokenizer itself.
    """
    return pipeline("question-answering", model=checkpoint)


# Module-level names kept so the rest of the script can keep using them.
model, tokenizer, text_classification_pipeline = load_text_classification(model_name)
qa_pipeline = load_question_answering(qa_model_name)
# Streamlit App Layout
st.title("News Classification and Q&A")

# Component 1: Text Classification Pipeline
st.header("Classify News Articles")
st.markdown("""
Upload a CSV file containing news articles, and the model will classify each article
into one of the following categories: Business, Opinion, Political Gossip, Sports, or World News.
""")


def preprocess_text(text):
    """Return ``text`` prepared for classification.

    Currently a pass-through; this is the hook for cleaning steps such as
    removing stopwords or special characters.
    """
    return text


uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    # A malformed, empty or wrongly-encoded upload should produce a readable
    # message in the page rather than an unhandled traceback.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.error(f"Could not read the uploaded CSV file: {exc}")
        df = None

    if df is not None and 'content' not in df.columns:
        st.error("The uploaded CSV file must have a 'content' column containing news excerpts.")
    elif df is not None:
        st.write("Preview of the data:")
        st.dataframe(df.head())

        # Missing cells arrive as NaN (a float); normalise everything to str
        # so the tokenizer never receives a non-string value.
        df['processed_content'] = (
            df['content'].fillna("").astype(str).apply(preprocess_text)
        )

        # Classify all rows in a single pipeline call. truncation=True clips
        # articles longer than the model's maximum input length instead of
        # letting the forward pass fail on them.
        texts = df['processed_content'].tolist()
        predictions = text_classification_pipeline(texts, truncation=True) if texts else []
        df['class'] = [prediction['label'] for prediction in predictions]

        # Show the results
        st.write("Classification Results:")
        st.dataframe(df[['content', 'class']])

        # Provide an option to download the output as CSV.
        # to_csv() returns the CSV text directly when no target is given.
        st.download_button(
            label="Download classified news",
            data=df.to_csv(index=False),
            file_name="output.csv",
            mime="text/csv",
        )
# Component 2: Q&A Pipeline
st.header("Ask a Question About the News")
st.markdown("""
Type in a question, and the model will extract an answer from the provided news content.
""")

# Strip both inputs so whitespace-only text is treated as "nothing entered"
# and never reaches the model.
question = st.text_input("Ask a question:").strip()
if question:
    # The context box is only shown once a question has been entered.
    context = st.text_area(
        "Provide the news article or content for the Q&A:",
        height=150,
    ).strip()
    if context:
        # Perform the question-answering task; the pipeline result is a
        # mapping whose 'answer' entry holds the extracted text span.
        result = qa_pipeline(question=question, context=context)
        st.write("Answer:", result['answer'])