File size: 3,592 Bytes
246133f
 
5cdc45c
237a63b
5cdc45c
ec68c76
6ad09a0
b704849
 
ebaf596
 
 
 
 
 
 
246133f
5cdc45c
237a63b
b704849
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7096073
b704849
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
import pandas as pd
import re
import io
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers import AutoModelForQuestionAnswering


# Hugging Face Hub checkpoints used by the two app components.
model_name_classification = "TAgroup5/news-classification-model"  # Replace with the correct model name
model_name_qa = "distilbert-base-cased-distilled-squad"


@st.cache_resource
def _load_classification_model(name):
    """Return ``(model, tokenizer)`` for the sequence-classification checkpoint *name*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects for the life of the server
    process so the weights are loaded once instead of on every rerun.
    """
    return (
        AutoModelForSequenceClassification.from_pretrained(name),
        AutoTokenizer.from_pretrained(name),
    )


@st.cache_resource
def _load_qa_model(name):
    """Return ``(model, tokenizer)`` for the question-answering checkpoint *name*.

    Cached with ``st.cache_resource`` so reruns reuse the loaded objects.
    """
    return (
        AutoModelForQuestionAnswering.from_pretrained(name),
        AutoTokenizer.from_pretrained(name),
    )


# Load fine-tuned models and tokenizers for both functions
model, tokenizer = _load_classification_model(model_name_classification)
model_qa, tokenizer_qa = _load_qa_model(model_name_qa)

# Initialize pipelines
text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
# The QA pipeline must be built from the question-answering model/tokenizer,
# not the sequence-classification ones.
qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)


# Precompiled patterns used by preprocess_text.
_NON_LETTER_RE = re.compile(r'[^a-z\s]')
_WHITESPACE_RE = re.compile(r'\s+')


def preprocess_text(text):
    """Normalize an article before classification.

    Lower-cases *text*, removes every character that is not an ASCII letter
    or whitespace, collapses whitespace runs to a single space and trims the
    ends. Non-string values (e.g. the NaN pandas uses for an empty CSV cell)
    become ``""``. Tokenization is left to the model's tokenizer.
    """
    if not isinstance(text, str):
        return ""
    text = _NON_LETTER_RE.sub('', text.lower())
    # Collapse whitespace *after* stripping characters so "a - b" -> "a b",
    # not "a  b".
    return _WHITESPACE_RE.sub(' ', text).strip()


def classify_text(text):
    """Return the predicted label for preprocessed *text*, or "Unknown" if it is blank.

    ``truncation=True`` clips inputs to the model's maximum sequence length;
    without it, long articles would be passed to the model untruncated.
    """
    if not text.strip():
        return "Unknown"
    return text_classification_pipeline(text, truncation=True)[0]['label']


# Streamlit App
st.title("News Classification and Q&A")

## ====================== Component 1: News Classification ====================== ##
st.header("Classify News Articles")
st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")

uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file, encoding="utf-8")
    except UnicodeDecodeError:
        # The failed UTF-8 attempt has already advanced the upload's read
        # position; rewind so the Latin-1 retry parses the file from the start.
        uploaded_file.seek(0)
        df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")

    if 'content' not in df.columns:
        st.error("Error: The uploaded CSV must contain a 'content' column.")
    else:
        st.write("Preview of uploaded data:")
        st.dataframe(df.head())

        # Apply preprocessing, then classify each record (blank rows -> "Unknown").
        df['processed_content'] = df['content'].apply(preprocess_text)
        df['class'] = df['processed_content'].apply(classify_text)

        # Show results
        st.write("Classification Results:")
        st.dataframe(df[['content', 'class']])

        # Provide CSV download; "utf-8-sig" writes a UTF-8 byte-order mark first.
        output = io.BytesIO()
        df.to_csv(output, index=False, encoding="utf-8-sig")
        st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")

## ====================== Component 2: Q&A ====================== ##
st.header("Ask a Question About the News")
st.markdown("Enter a question and provide a news article to get an answer.")

question = st.text_input("Ask a question:")
context = st.text_area("Provide the news article or content for the Q&A:", height=150)

if question.strip() and context.strip():
    # Build the pipeline from the QA model/tokenizer already loaded at the top
    # of the script (checkpoint `model_name_qa`), rather than fetching a second,
    # different checkpoint by name on every question.
    qa_pipeline = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)
    result = qa_pipeline(question=question, context=context)

    # A single question/context pair yields a dict; an empty span means no answer.
    answer = result.get('answer') if isinstance(result, dict) else None
    if answer:
        st.write("Answer:", answer)
    else:
        st.write("No answer found in the provided content.")