Ashendilantha commited on
Commit
42bdc4d
Β·
verified Β·
1 Parent(s): e8f4b00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -36
app.py CHANGED
@@ -8,9 +8,11 @@ from nltk.tokenize import word_tokenize
8
  from nltk.stem import WordNetLemmatizer
9
  import torch
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
 
11
 
12
  # Set page configuration
13
- st.set_page_config(page_title="News Analysis App", layout="wide")
14
 
15
  # Download required NLTK resources
16
  @st.cache_resource
@@ -25,10 +27,10 @@ download_nltk_resources()
25
  stop_words = set(stopwords.words('english'))
26
  lemmatizer = WordNetLemmatizer()
27
 
28
- # Load classification model
29
  @st.cache_resource
30
  def load_classification_model():
31
- model_name = "Oneli/News_Classification"
32
  tokenizer = AutoTokenizer.from_pretrained(model_name)
33
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
34
  return model, tokenizer
@@ -39,24 +41,43 @@ def load_qa_pipeline():
39
  qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
40
  return qa_pipeline
41
 
42
- # Preprocessing function
43
  def preprocess_text(text):
44
  if pd.isna(text):
45
  return ""
46
 
 
47
  text = text.lower()
 
 
48
  text = re.sub(r'http\S+|www\S+|https\S+', '', text)
 
 
49
  text = re.sub(r'<.*?>', '', text)
 
 
50
  text = re.sub(r'[^a-zA-Z\s]', '', text)
 
 
51
  tokens = word_tokenize(text)
 
 
52
  cleaned_tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
 
 
53
  cleaned_text = ' '.join(cleaned_tokens)
 
54
  return cleaned_text
55
 
56
- # Batch classification function
57
  def classify_news(df, model, tokenizer):
 
58
  df['cleaned_content'] = df['content'].apply(preprocess_text)
 
 
59
  texts = df['cleaned_content'].tolist()
 
 
60
  predictions = []
61
  batch_size = 16
62
 
@@ -70,68 +91,113 @@ def classify_news(df, model, tokenizer):
70
  batch_predictions = torch.argmax(logits, dim=1).tolist()
71
  predictions.extend(batch_predictions)
72
 
 
73
  id2label = model.config.id2label
74
  df['class'] = [id2label[pred] for pred in predictions]
 
75
  return df
76
 
77
  # Main app
78
  def main():
79
- st.title("News Analysis Application")
 
 
80
  st.sidebar.title("Navigation")
81
  app_mode = st.sidebar.radio("Choose the app mode", ["News Classification", "Question Answering"])
82
 
 
83
  if app_mode == "News Classification":
84
- st.header("News Article Classification")
85
- uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- if uploaded_file is not None:
88
- df = pd.read_csv(uploaded_file)
 
 
 
 
 
89
  st.subheader("Sample of uploaded data")
90
  st.dataframe(df.head())
91
 
 
92
  if 'content' not in df.columns:
93
- st.error("The CSV file must contain a 'content' column.")
94
  else:
95
- with st.spinner("Loading model..."):
 
96
  model, tokenizer = load_classification_model()
97
 
 
98
  if st.button("Classify Articles"):
99
  with st.spinner("Classifying news articles..."):
 
100
  result_df = classify_news(df, model, tokenizer)
 
 
101
  st.subheader("Classification Results")
102
  st.dataframe(result_df[['content', 'class']])
 
 
103
  csv = result_df.to_csv(index=False)
104
- st.download_button("Download output.csv", csv, "output.csv", "text/csv")
 
 
 
 
 
 
 
105
  st.subheader("Class Distribution")
106
- st.bar_chart(result_df['class'].value_counts())
 
107
 
 
108
  elif app_mode == "Question Answering":
109
- st.header("News Article Q&A")
110
- uploaded_file = st.file_uploader("Upload CSV for Q&A", type="csv")
111
 
112
- if uploaded_file is not None:
113
- df = pd.read_csv(uploaded_file)
114
- if 'content' not in df.columns:
115
- st.error("The CSV file must contain a 'content' column.")
116
- else:
117
- combined_text = " ".join(df['cleaned_content'].dropna().astype(str).tolist())
118
- question = st.text_input("Enter your question about the news:")
119
-
120
- if combined_text and question:
121
- with st.spinner("Loading Q&A model..."):
122
- qa_pipeline = load_qa_pipeline()
 
 
 
 
 
 
 
 
123
 
124
- if st.button("Get Answer"):
125
- with st.spinner("Finding answer..."):
126
- result = qa_pipeline(question=question, context=combined_text)
127
- st.subheader("Answer")
128
- st.write(result["answer"])
129
- st.subheader("Confidence")
130
- st.progress(float(result["score"]))
131
- st.write(f"Confidence Score: {result['score']:.4f}")
132
 
133
  if __name__ == "__main__":
134
  main()
135
 
136
 
137
-
 
8
  from nltk.stem import WordNetLemmatizer
9
  import torch
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
11
+ import requests
12
+ from io import BytesIO
13
 
14
  # Set page configuration
15
+ st.set_page_config(page_title="News Classifier", page_icon="πŸ“°")
16
 
17
  # Download required NLTK resources
18
  @st.cache_resource
 
27
  stop_words = set(stopwords.words('english'))
28
  lemmatizer = WordNetLemmatizer()
29
 
30
+ # Load the fine-tuned model for classification
31
  @st.cache_resource
32
  def load_classification_model():
33
+ model_name = "Oneli/News_Classification" # Replace with your actual model path
34
  tokenizer = AutoTokenizer.from_pretrained(model_name)
35
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
36
  return model, tokenizer
 
41
  qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
42
  return qa_pipeline
43
 
44
+ # Text preprocessing function
45
  def preprocess_text(text):
46
  if pd.isna(text):
47
  return ""
48
 
49
+ # Convert to lowercase
50
  text = text.lower()
51
+
52
+ # Remove URLs
53
  text = re.sub(r'http\S+|www\S+|https\S+', '', text)
54
+
55
+ # Remove HTML tags
56
  text = re.sub(r'<.*?>', '', text)
57
+
58
+ # Remove special characters and numbers
59
  text = re.sub(r'[^a-zA-Z\s]', '', text)
60
+
61
+ # Tokenize
62
  tokens = word_tokenize(text)
63
+
64
+ # Remove stopwords and lemmatize
65
  cleaned_tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
66
+
67
+ # Join tokens back into text
68
  cleaned_text = ' '.join(cleaned_tokens)
69
+
70
  return cleaned_text
71
 
72
+ # Function to classify news articles with batch processing
73
  def classify_news(df, model, tokenizer):
74
+ # Preprocess the text
75
  df['cleaned_content'] = df['content'].apply(preprocess_text)
76
+
77
+ # Prepare for classification
78
  texts = df['cleaned_content'].tolist()
79
+
80
+ # Get predictions
81
  predictions = []
82
  batch_size = 16
83
 
 
91
  batch_predictions = torch.argmax(logits, dim=1).tolist()
92
  predictions.extend(batch_predictions)
93
 
94
+ # Map numeric predictions back to class labels
95
  id2label = model.config.id2label
96
  df['class'] = [id2label[pred] for pred in predictions]
97
+
98
  return df
99
 
100
  # Main app
101
  def main():
102
+ st.title("News Classifier πŸ“’")
103
+
104
+ # Sidebar for navigation
105
  st.sidebar.title("Navigation")
106
  app_mode = st.sidebar.radio("Choose the app mode", ["News Classification", "Question Answering"])
107
 
108
+ # Section for Single Article Classification
109
  if app_mode == "News Classification":
110
+ st.header("πŸ“° Single Article Classification")
111
+ st.write("Enter a news article or upload a CSV file to classify the content.")
112
+
113
+ # Text input for single article classification
114
+ text_input = st.text_area("Enter News Text", placeholder="Type or paste news content here...")
115
+ if st.button("πŸ” Classify"):
116
+ if text_input:
117
+ # Load classification model
118
+ with st.spinner("Loading classification model..."):
119
+ model, tokenizer = load_classification_model()
120
+
121
+ # Classify the text
122
+ with st.spinner("Classifying the article..."):
123
+ category, confidence = classify_text(text_input, model, tokenizer)
124
+ st.write(f"*Predicted Category:* {category}")
125
+ st.write(f"*Confidence Level:* {confidence}%")
126
+ else:
127
+ st.warning("Please enter some text to classify.")
128
 
129
+ # File upload for bulk classification
130
+ st.subheader("πŸ“‚ Bulk Classification (CSV)")
131
+ file_input = st.file_uploader("Upload CSV File", type="csv")
132
+ if file_input:
133
+ df = pd.read_csv(file_input)
134
+
135
+ # Display sample of the data
136
  st.subheader("Sample of uploaded data")
137
  st.dataframe(df.head())
138
 
139
+ # Check if the required column exists
140
  if 'content' not in df.columns:
141
+ st.error("The CSV file must contain a 'content' column with the news articles text.")
142
  else:
143
+ # Load model and tokenizer
144
+ with st.spinner("Loading classification model..."):
145
  model, tokenizer = load_classification_model()
146
 
147
+ # Classify button
148
  if st.button("Classify Articles"):
149
  with st.spinner("Classifying news articles..."):
150
+ # Perform classification
151
  result_df = classify_news(df, model, tokenizer)
152
+
153
+ # Display results
154
  st.subheader("Classification Results")
155
  st.dataframe(result_df[['content', 'class']])
156
+
157
+ # Save to CSV
158
  csv = result_df.to_csv(index=False)
159
+ st.download_button(
160
+ label="Download output.csv",
161
+ data=csv,
162
+ file_name="output.csv",
163
+ mime="text/csv"
164
+ )
165
+
166
+ # Show distribution of classes
167
  st.subheader("Class Distribution")
168
+ class_counts = result_df['class'].value_counts()
169
+ st.bar_chart(class_counts)
170
 
171
+ # Section for Question Answering
172
  elif app_mode == "Question Answering":
173
+ st.header("πŸ’¬ AI Chat Assistant")
174
+ st.write("Ask questions about news content and get answers using a Q&A model.")
175
 
176
+ # Text area for news content
177
+ news_content = st.text_area("Paste news article content here:", height=200)
178
+
179
+ # Question input
180
+ question = st.text_input("Enter your question about the article:")
181
+
182
+ if news_content and question:
183
+ # Load QA pipeline
184
+ with st.spinner("Loading Q&A model..."):
185
+ qa_pipeline = load_qa_pipeline()
186
+
187
+ # Get answer
188
+ if st.button("Get Answer"):
189
+ with st.spinner("Finding answer..."):
190
+ result = qa_pipeline(question=question, context=news_content)
191
+
192
+ # Display results
193
+ st.subheader("Answer")
194
+ st.write(result["answer"])
195
 
196
+ st.subheader("Confidence")
197
+ st.progress(float(result["score"]))
198
+ st.write(f"Confidence Score: {result['score']:.4f}")
 
 
 
 
 
199
 
200
  if __name__ == "__main__":
201
  main()
202
 
203