TAgroup5 committed on
Commit
5cdc45c
·
verified ·
1 Parent(s): b9a2eea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -41
app.py CHANGED
@@ -1,75 +1,66 @@
1
  import streamlit as st
2
  import pandas as pd
3
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
4
  import io
 
5
 
6
- # Load pre-trained model and tokenizer for text classification
7
- model_name = "TAgroup5/news-classification-model"
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
11
- # Initialize the text classification pipeline
12
  text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
13
-
14
- # Initialize the question answering pipeline
15
  qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
16
 
17
- # Streamlit App Layout
18
  st.title("News Classification and Q&A")
19
 
20
- # Component 1: Text Classification Pipeline
21
  st.header("Classify News Articles")
22
-
23
- st.markdown("""
24
- Upload a CSV file containing news articles, and the model will classify each article
25
- into one of the following categories: Business, Opinion, Political Gossip, Sports, or World News.
26
- """)
27
 
28
  uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
29
 
30
  if uploaded_file is not None:
31
- df = pd.read_csv(uploaded_file)
32
-
 
 
 
33
  if 'content' not in df.columns:
34
- st.error("The uploaded CSV file must have a 'content' column containing news excerpts.")
35
  else:
36
- st.write("Preview of the data:")
37
  st.dataframe(df.head())
38
 
39
- # Preprocess the data and classify each article
40
  def preprocess_text(text):
41
- # Apply necessary preprocessing steps here (e.g., removing stopwords, special characters, etc.)
 
 
42
  return text
43
-
44
  # Apply preprocessing and classification
45
  df['processed_content'] = df['content'].apply(preprocess_text)
46
- df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'])
47
-
48
- # Show the results
49
  st.write("Classification Results:")
50
  st.dataframe(df[['content', 'class']])
51
 
52
- # Provide an option to download the output as CSV
53
- output = io.StringIO()
54
- df.to_csv(output, index=False)
55
  st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")
56
 
57
-
58
- # Component 2: Q&A Pipeline
59
  st.header("Ask a Question About the News")
60
-
61
- st.markdown("""
62
- Type in a question, and the model will extract an answer from the provided news content.
63
- """)
64
 
65
  question = st.text_input("Ask a question:")
 
66
 
67
- if question:
68
- context = st.text_area("Provide the news article or content for the Q&A:", height=150)
69
-
70
- if context:
71
- # Perform the question-answering task
72
- result = qa_pipeline(question=question, context=context)
73
-
74
- st.write("Answer:", result['answer'])
75
-
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import re
4
  import io
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
 
7
+ # Load fine-tuned model and tokenizer
8
+ model_name = "TAgroup5/daily-mirror-news-classifier"
9
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
 
12
+ # Initialize pipelines
13
  text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
 
 
14
  # NOTE(review): `model` was loaded with AutoModelForSequenceClassification,
  # yet it is reused here for the "question-answering" task. That task
  # presumably needs a checkpoint with a QA head - confirm this produces
  # meaningful answers, or load a dedicated QA model for this pipeline.
  qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
15
 
16
+ # Streamlit App
17
  st.title("News Classification and Q&A")
18
 
19
+ ## ====================== Component 1: News Classification ====================== ##
20
  st.header("Classify News Articles")
21
+ st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
 
 
 
 
22
 
23
  uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
24
 
25
  if uploaded_file is not None:
26
try:
    # Try UTF-8 first; it is the most common encoding for CSV exports.
    df = pd.read_csv(uploaded_file, encoding="utf-8")
except UnicodeDecodeError:
    # The failed attempt has already consumed part of the upload stream,
    # so rewind to the start before retrying; otherwise the fallback parse
    # would begin mid-file (or at EOF) and lose rows or fail outright.
    uploaded_file.seek(0)
    # ISO-8859-1 maps every byte to a character, so this decode cannot fail.
    df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
30
+
31
  if 'content' not in df.columns:
32
+ st.error("Error: The uploaded CSV must contain a 'content' column.")
33
  else:
34
+ st.write("Preview of uploaded data:")
35
  st.dataframe(df.head())
36
 
37
+ # Preprocessing function
38
  def preprocess_text(text):
39
+ text = text.lower() # Ensure consistent casing
40
+ text = re.sub(r'\s+', ' ', text) # Remove extra spaces
41
+ text = re.sub(r'[^a-zA-Z0-9\s]', '', text) # Remove special characters
42
  return text
43
+
44
  # Apply preprocessing and classification
45
  df['processed_content'] = df['content'].apply(preprocess_text)
46
+ df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
47
+
48
+ # Show results
49
  st.write("Classification Results:")
50
  st.dataframe(df[['content', 'class']])
51
 
52
+ # Provide CSV download
53
+ output = io.BytesIO()
54
+ df.to_csv(output, index=False, encoding="utf-8-sig")
55
  st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")
56
 
57
+ ## ====================== Component 2: Q&A ====================== ##
 
58
  st.header("Ask a Question About the News")
59
+ st.markdown("Enter a question and provide a news article to get an answer.")
 
 
 
60
 
61
  question = st.text_input("Ask a question:")
62
+ context = st.text_area("Provide the news article or content for the Q&A:", height=150)
63
 
64
+ if question and context.strip():
65
+ result = qa_pipeline(question=question, context=context)
66
+ st.write("Answer:", result['answer'])