TAgroup5 committed on
Commit
73c0f99
·
verified ·
1 Parent(s): b704849

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -40
app.py CHANGED
@@ -4,81 +4,110 @@ import re
4
  import io
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
7
-
8
-
9
- # Load fine-tuned models and tokenizers for both functions
10
- model_name_classification = "TAgroup5/news-classification-model" # Replace with the correct model name
11
- model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
12
- tokenizer = AutoTokenizer.from_pretrained(model_name_classification)
13
-
14
- model_name_qa = "distilbert-base-cased-distilled-squad"
15
- model_qa = AutoModelForQuestionAnswering.from_pretrained(model_name_qa)
16
- tokenizer_qa = AutoTokenizer.from_pretrained(model_name_qa)
17
-
18
- # Initialize pipelines
19
- text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
20
- qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
21
-
22
-
23
- # Streamlit App
24
- st.title("News Classification and Q&A")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  ## ====================== Component 1: News Classification ====================== ##
27
- st.header("Classify News Articles")
28
  st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
29
 
30
  uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
31
 
32
  if uploaded_file is not None:
33
  try:
34
- df = pd.read_csv(uploaded_file, encoding="utf-8") # Handle encoding issues
35
  except UnicodeDecodeError:
36
  df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
37
 
38
  if 'content' not in df.columns:
39
- st.error("Error: The uploaded CSV must contain a 'content' column.")
40
  else:
 
41
  st.write("Preview of uploaded data:")
42
  st.dataframe(df.head())
43
 
44
- # Preprocessing function to clean the text
45
  def preprocess_text(text):
46
- text = text.lower() # Convert to lowercase
47
- text = re.sub(r'\s+', ' ', text) # Remove extra spaces
48
- text = re.sub(r'[^a-z\s]', '', text) # Remove special characters & numbers
49
- # You don't need tokenization here, as the model tokenizer will handle it
50
  return text
51
 
52
-
53
- # Apply preprocessing and classification
54
  df['processed_content'] = df['content'].apply(preprocess_text)
55
-
56
- # Classify each record into one of the five classes
 
 
 
 
 
 
57
  df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
58
 
59
- # Show results
60
- st.write("Classification Results:")
61
  st.dataframe(df[['content', 'class']])
62
 
63
  # Provide CSV download
64
  output = io.BytesIO()
65
  df.to_csv(output, index=False, encoding="utf-8-sig")
66
- st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")
67
 
68
  ## ====================== Component 2: Q&A ====================== ##
69
- st.header("Ask a Question About the News")
70
  st.markdown("Enter a question and provide a news article to get an answer.")
71
 
72
- question = st.text_input("Ask a question:")
73
- context = st.text_area("Provide the news article or content for the Q&A:", height=150)
74
 
75
  if question and context.strip():
76
- model_name_qa = "distilbert-base-uncased-distilled-squad" # Example of a common Q&A model
77
  qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
78
  result = qa_pipeline(question=question, context=context)
79
 
80
- # Check if the result contains an answer
81
  if 'answer' in result and result['answer']:
82
- st.write("Answer:", result['answer'])
83
  else:
84
- st.write("No answer found in the provided content.")
 
4
  import io
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
7
+ from streamlit_extras.app_logo import add_logo # For adding a logo
8
+
9
+ # Custom Styling
10
+ st.set_page_config(page_title="News Classifier & Q&A", page_icon="📰", layout="wide")
11
+
12
+ # CSS for styling
13
+ st.markdown(
14
+ """
15
+ <style>
16
+ body {
17
+ background-color: #f5f5f5;
18
+ }
19
+ .stApp {
20
+ background-color: white;
21
+ border-radius: 10px;
22
+ padding: 20px;
23
+ box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.1);
24
+ }
25
+ .stTitle, .stHeader {
26
+ color: #0073e6;
27
+ text-align: center;
28
+ }
29
+ .stButton>button {
30
+ background-color: #0073e6 !important;
31
+ color: white !important;
32
+ border-radius: 8px !important;
33
+ font-size: 16px !important;
34
+ }
35
+ .stDownloadButton>button {
36
+ background-color: #28a745 !important;
37
+ color: white !important;
38
+ border-radius: 8px !important;
39
+ }
40
+ </style>
41
+ """,
42
+ unsafe_allow_html=True,
43
+ )
44
+
45
+ # Add a logo (optional, replace with your logo URL)
46
+ # add_logo("https://your-logo-url.png", height=50)
47
+
48
+ st.title("📰 News Classification & Q&A")
49
 
50
  ## ====================== Component 1: News Classification ====================== ##
51
+ st.header("📌 Classify News Articles")
52
  st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
53
 
54
  uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
55
 
56
  if uploaded_file is not None:
57
  try:
58
+ df = pd.read_csv(uploaded_file, encoding="utf-8")
59
  except UnicodeDecodeError:
60
  df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
61
 
62
  if 'content' not in df.columns:
63
+ st.error("❌ The uploaded CSV must contain a 'content' column.")
64
  else:
65
+ st.success("✅ File uploaded successfully!")
66
  st.write("Preview of uploaded data:")
67
  st.dataframe(df.head())
68
 
69
+ # Preprocessing function
70
  def preprocess_text(text):
71
+ text = text.lower()
72
+ text = re.sub(r'\s+', ' ', text)
73
+ text = re.sub(r'[^a-z\s]', '', text)
 
74
  return text
75
 
76
+ # Apply preprocessing
 
77
  df['processed_content'] = df['content'].apply(preprocess_text)
78
+
79
+ # Load Model
80
+ model_name_classification = "TAgroup5/news-classification-model"
81
+ model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
82
+ tokenizer = AutoTokenizer.from_pretrained(model_name_classification)
83
+ text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
84
+
85
+ # Classify each record
86
  df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
87
 
88
+ # Display results
89
+ st.write("📌 Classification Results:")
90
  st.dataframe(df[['content', 'class']])
91
 
92
  # Provide CSV download
93
  output = io.BytesIO()
94
  df.to_csv(output, index=False, encoding="utf-8-sig")
95
+ st.download_button(label="📥 Download Classified News", data=output.getvalue(), file_name="classified_news.csv", mime="text/csv")
96
 
97
  ## ====================== Component 2: Q&A ====================== ##
98
+ st.header("💬 Ask a Question About the News")
99
  st.markdown("Enter a question and provide a news article to get an answer.")
100
 
101
+ question = st.text_input("🔍 Ask a question:")
102
+ context = st.text_area("📝 Provide the news article content:", height=150)
103
 
104
  if question and context.strip():
105
+ model_name_qa = "distilbert-base-uncased-distilled-squad"
106
  qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
107
  result = qa_pipeline(question=question, context=context)
108
 
109
+ # Display answer
110
  if 'answer' in result and result['answer']:
111
+ st.success(f"**🗣 Answer:** {result['answer']}")
112
  else:
113
+ st.warning("⚠️ No answer found in the provided content.")