# Streamlit app: classify news articles from an uploaded CSV and answer questions about an article.
# (Web-viewer residue that preceded the code -- a file-size banner, commit hashes and a
# line-number gutter -- was not part of the program and has been dropped.)
import streamlit as st
import pandas as pd
import re
import io
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers import AutoModelForQuestionAnswering


# Load fine-tuned models and tokenizers for both functions.
model_name_classification = "TAgroup5/news-classification-model"  # Replace with the correct model name
model_name_qa = "distilbert-base-cased-distilled-squad"


@st.cache_resource
def _load_classification_assets(name):
    """Return ``(model, tokenizer, pipeline)`` for the text-classification checkpoint *name*.

    Wrapped in ``st.cache_resource`` so that Streamlit script reruns reuse the
    already-loaded objects instead of loading the checkpoint again.
    """
    clf_model = AutoModelForSequenceClassification.from_pretrained(name)
    clf_tokenizer = AutoTokenizer.from_pretrained(name)
    clf_pipeline = pipeline("text-classification", model=clf_model, tokenizer=clf_tokenizer)
    return clf_model, clf_tokenizer, clf_pipeline


@st.cache_resource
def _load_qa_assets(name):
    """Return ``(model, tokenizer, pipeline)`` for the question-answering checkpoint *name*.

    Cached with ``st.cache_resource`` for the same reason as the classifier.
    """
    loaded_model = AutoModelForQuestionAnswering.from_pretrained(name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    # The QA pipeline must be built from the QA model/tokenizer; handing it the
    # sequence-classification model would not give it a question-answering head.
    loaded_pipeline = pipeline("question-answering", model=loaded_model, tokenizer=loaded_tokenizer)
    return loaded_model, loaded_tokenizer, loaded_pipeline


# Initialize pipelines (the module-level names are kept for the rest of the script).
model, tokenizer, text_classification_pipeline = _load_classification_assets(model_name_classification)
model_qa, tokenizer_qa, qa_pipeline = _load_qa_assets(model_name_qa)


# Streamlit App
st.title("News Classification and Q&A")

## ====================== Component 1: News Classification ====================== ##
st.header("Classify News Articles")
st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")

uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    # Try UTF-8 first and fall back to ISO-8859-1 when the bytes are not valid UTF-8.
    # The buffer is rewound before every read: a failed (or earlier) read leaves the
    # stream position past the start, so a retry without seek(0) would parse only
    # whatever bytes happen to remain.
    try:
        uploaded_file.seek(0)
        df = pd.read_csv(uploaded_file, encoding="utf-8")  # Handle encoding issues
    except UnicodeDecodeError:
        uploaded_file.seek(0)
        df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")

    if 'content' not in df.columns:
        st.error("Error: The uploaded CSV must contain a 'content' column.")
    else:
        st.write("Preview of uploaded data:")
        st.dataframe(df.head())

        # Preprocessing function to clean the text
        def preprocess_text(text):
            """Normalise one article's text before it is handed to the classifier.

            Lower-cases the text, drops every character that is not ``a-z`` or
            whitespace, then collapses whitespace runs to a single space and
            trims the ends. Tokenization is left to the model's tokenizer.

            Args:
                text: The raw cell value. Non-string values (for example the
                    NaN pandas uses for an empty CSV cell) are treated as
                    empty text instead of raising.

            Returns:
                The cleaned string; ``""`` when there is nothing usable.
            """
            if not isinstance(text, str):
                return ""
            text = text.lower()  # Convert to lowercase
            text = re.sub(r'[^a-z\s]', '', text)  # Remove special characters & numbers
            # Collapse whitespace only after stripping characters, so the gaps that
            # removal leaves behind (e.g. "a - b" -> "a  b") are collapsed as well.
            text = re.sub(r'\s+', ' ', text).strip()
            return text


        # Apply preprocessing and classification.
        # Empty CSV cells arrive as NaN (a float), so coerce the column to text first;
        # the original 'content' column is left untouched for display and export.
        content_as_text = df['content'].fillna('').astype(str)
        df['processed_content'] = content_as_text.apply(preprocess_text)

        def classify_text(cleaned):
            """Return the pipeline's top label for *cleaned*, or "Unknown" for blank text."""
            if not cleaned.strip():
                return "Unknown"
            # truncation=True asks the tokenizer to cut inputs that exceed the model's
            # maximum sequence length, so long articles do not fail in the forward pass.
            return text_classification_pipeline(cleaned, truncation=True)[0]['label']

        # Classify each record
        df['class'] = df['processed_content'].apply(classify_text)

        # Show results
        st.write("Classification Results:")
        st.dataframe(df[['content', 'class']])

        # Provide CSV download (utf-8-sig writes a BOM at the start of the export)
        output = io.BytesIO()
        df.to_csv(output, index=False, encoding="utf-8-sig")
        st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")

## ====================== Component 2: Q&A ====================== ##
st.header("Ask a Question About the News")
st.markdown("Enter a question and provide a news article to get an answer.")

question = st.text_input("Ask a question:")
context = st.text_area("Provide the news article or content for the Q&A:", height=150)

# Only run the model when both fields contain something other than whitespace.
if question.strip() and context.strip():
    # Build the pipeline from the QA model and tokenizer already loaded at the top of
    # the file, rather than fetching a different checkpoint by name on every rerun
    # and rebinding the module-level names.
    news_qa = pipeline("question-answering", model=model_qa, tokenizer=tokenizer_qa)
    result = news_qa(question=question, context=context)

    # Check if the result contains an answer
    answer = result.get('answer')
    if answer:
        st.write("Answer:", answer)
    else:
        st.write("No answer found in the provided content.")