TAgroup5 committed on
Commit
48cfda7
·
verified ·
1 Parent(s): ec68c76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -29
app.py CHANGED
@@ -5,8 +5,7 @@ import io
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
7
 
8
-
9
- # Load fine-tuned models and tokenizers for both functions
10
  model_name_classification = "TAgroup5/news-classification-model" # Replace with the correct model name
11
  model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
12
  tokenizer = AutoTokenizer.from_pretrained(model_name_classification)
@@ -17,14 +16,28 @@ tokenizer_qa = AutoTokenizer.from_pretrained(model_name_qa)
17
 
18
  # Initialize pipelines
19
  text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
20
- qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
21
-
22
-
23
- # Streamlit App
24
- st.title("News Classification and Q&A")
25
-
26
- ## ====================== Component 1: News Classification ====================== ##
27
- st.header("Classify News Articles")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
29
 
30
  uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
@@ -34,51 +47,48 @@ if uploaded_file is not None:
34
  df = pd.read_csv(uploaded_file, encoding="utf-8") # Handle encoding issues
35
  except UnicodeDecodeError:
36
  df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
37
-
38
  if 'content' not in df.columns:
39
- st.error("Error: The uploaded CSV must contain a 'content' column.")
40
  else:
 
41
  st.write("Preview of uploaded data:")
42
  st.dataframe(df.head())
43
-
44
  # Preprocessing function to clean the text
45
  def preprocess_text(text):
46
  text = text.lower() # Convert to lowercase
47
  text = re.sub(r'\s+', ' ', text) # Remove extra spaces
48
  text = re.sub(r'[^a-z\s]', '', text) # Remove special characters & numbers
49
- # You don't need tokenization here, as the model tokenizer will handle it
50
  return text
51
 
52
-
53
  # Apply preprocessing and classification
54
  df['processed_content'] = df['content'].apply(preprocess_text)
55
-
56
- # Classify each record into one of the five classes
57
  df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
58
-
59
  # Show results
60
- st.write("Classification Results:")
61
  st.dataframe(df[['content', 'class']])
62
-
63
  # Provide CSV download
64
  output = io.BytesIO()
65
  df.to_csv(output, index=False, encoding="utf-8-sig")
66
- st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")
67
 
68
- ## ====================== Component 2: Q&A ====================== ##
69
- st.header("Ask a Question About the News")
70
  st.markdown("Enter a question and provide a news article to get an answer.")
71
 
72
- question = st.text_input("Ask a question:")
73
- context = st.text_area("Provide the news article or content for the Q&A:", height=150)
74
 
75
  if question and context.strip():
76
- model_name_qa = "distilbert-base-uncased-distilled-squad" # Example of a common Q&A model
77
  qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
78
  result = qa_pipeline(question=question, context=context)
79
 
80
- # Check if the result contains an answer
81
  if 'answer' in result and result['answer']:
82
- st.write("Answer:", result['answer'])
83
  else:
84
- st.write("No answer found in the provided content.")
 
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
7
 
8
+ # Load fine-tuned models and tokenizers for both functions
 
9
  model_name_classification = "TAgroup5/news-classification-model" # Replace with the correct model name
10
  model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
11
  tokenizer = AutoTokenizer.from_pretrained(model_name_classification)
 
16
 
17
  # Initialize pipelines
18
  text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
19
+ qa_pipeline = pipeline("question-answering", model=model, tokenizer=model)
20
+
21
+ # Streamlit App Styling
22
+ st.set_page_config(page_title="News Classification & Q&A", page_icon="📰", layout="wide")
23
+ st.markdown(
24
+ """
25
+ <style>
26
+ body {background-color: #f4f4f4;}
27
+ .title {text-align: center; font-size: 36px; font-weight: bold; color: #ff4b4b;}
28
+ .subheader {text-align: center; font-size: 24px; color: #333; margin-bottom: 20px;}
29
+ .stTextInput>div>div>input {border-radius: 10px;}
30
+ .stTextArea>div>div>textarea {border-radius: 10px;}
31
+ .stButton>button {border-radius: 10px; background-color: #ff4b4b; color: white; font-weight: bold;}
32
+ </style>
33
+ """,
34
+ unsafe_allow_html=True,
35
+ )
36
+
37
+ st.markdown('<h1 class="title">📰 News Classification & Q&A App</h1>', unsafe_allow_html=True)
38
+
39
+ # ====================== Component 1: News Classification ====================== #
40
+ st.markdown('<h2 class="subheader">📌 Classify News Articles</h2>', unsafe_allow_html=True)
41
  st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
42
 
43
  uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
 
47
  df = pd.read_csv(uploaded_file, encoding="utf-8") # Handle encoding issues
48
  except UnicodeDecodeError:
49
  df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
50
+
51
  if 'content' not in df.columns:
52
+ st.error("❌ Error: The uploaded CSV must contain a 'content' column.")
53
  else:
54
+ st.success("✅ File successfully uploaded!")
55
  st.write("Preview of uploaded data:")
56
  st.dataframe(df.head())
57
+
58
  # Preprocessing function to clean the text
59
  def preprocess_text(text):
60
  text = text.lower() # Convert to lowercase
61
  text = re.sub(r'\s+', ' ', text) # Remove extra spaces
62
  text = re.sub(r'[^a-z\s]', '', text) # Remove special characters & numbers
 
63
  return text
64
 
 
65
  # Apply preprocessing and classification
66
  df['processed_content'] = df['content'].apply(preprocess_text)
 
 
67
  df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
68
+
69
  # Show results
70
+ st.markdown("### 🔹 Classification Results:")
71
  st.dataframe(df[['content', 'class']])
72
+
73
  # Provide CSV download
74
  output = io.BytesIO()
75
  df.to_csv(output, index=False, encoding="utf-8-sig")
76
+ st.download_button(label="⬇️ Download classified news", data=output.getvalue(), file_name="classified_news.csv", mime="text/csv")
77
 
78
+ # ====================== Component 2: Q&A ====================== #
79
+ st.markdown('<h2 class="subheader">❓ Ask a Question About the News</h2>', unsafe_allow_html=True)
80
  st.markdown("Enter a question and provide a news article to get an answer.")
81
 
82
+ question = st.text_input("🔍 Ask a question:")
83
+ context = st.text_area("📝 Provide the news article or content:", height=150)
84
 
85
  if question and context.strip():
86
+ model_name_qa = "distilbert-base-uncased-distilled-squad"
87
  qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
88
  result = qa_pipeline(question=question, context=context)
89
 
90
+ # Display Answer
91
  if 'answer' in result and result['answer']:
92
+ st.markdown(f"### ✅ Answer: {result['answer']}")
93
  else:
94
+ st.markdown("### ❌ No answer found in the provided content.")