Ashendilantha committed
Commit 77d6801 · verified · 1 Parent(s): 0702d1e

Update app.py

Files changed (1):
  1. app.py +94 -198

app.py CHANGED
@@ -1,219 +1,115 @@
  import streamlit as st
  import pandas as pd
- import torch
  import re
  from nltk.corpus import stopwords
  from nltk.tokenize import word_tokenize
  from nltk.stem import WordNetLemmatizer
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
- import nltk

- # Set page configuration
- st.set_page_config(page_title="News Analysis App", layout="wide")

- # Download required NLTK resources
- @st.cache_resource
- def download_nltk_resources():
-     nltk.download('punkt')
-     nltk.download('stopwords')
-     nltk.download('wordnet')
-
- download_nltk_resources()
-
- # Initialize preprocessor components
- stop_words = set(stopwords.words('english'))
- lemmatizer = WordNetLemmatizer()
-
- # Load the fine-tuned model for classification
- @st.cache_resource
- def load_classification_model():
-     model_name = "Oneli/News_Classification"  # Replace with your actual model path
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForSequenceClassification.from_pretrained(model_name)
-     return model, tokenizer
-
- # Load Q&A pipeline
- @st.cache_resource
- def load_qa_pipeline():
-     qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
-     return qa_pipeline

- # Text preprocessing function
  def preprocess_text(text):
-     if pd.isna(text):
-         return ""
-
-     # Convert to lowercase
      text = text.lower()
-
-     # Remove URLs
-     text = re.sub(r'http\S+|www\S+|https\S+', '', text)
-
-     # Remove HTML tags
-     text = re.sub(r'<.*?>', '', text)
-
-     # Remove special characters and numbers
-     text = re.sub(r'[^a-zA-Z\s]', '', text)
-
-     # Tokenize
-     tokens = word_tokenize(text)
-
-     # Remove stopwords and lemmatize
-     cleaned_tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
-
-     # Join tokens back into text
-     cleaned_text = ' '.join(cleaned_tokens)
-
-     return cleaned_text
-
- # Function to classify news articles (bulk processing)
- def classify_news(df, model, tokenizer):
-     # Preprocess the text
-     df['cleaned_content'] = df['content'].apply(preprocess_text)
-
-     # Prepare for classification
-     texts = df['cleaned_content'].tolist()
-
-     # Get predictions
-     predictions = []
-     batch_size = 16
-
-     for i in range(0, len(texts), batch_size):
-         batch_texts = texts[i:i+batch_size]
-         inputs = tokenizer(batch_texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
-
-         with torch.no_grad():
-             outputs = model(**inputs)
-             logits = outputs.logits
-             batch_predictions = torch.argmax(logits, dim=1).tolist()
-             predictions.extend(batch_predictions)
-
-     # Map numeric predictions back to class labels
-     id2label = model.config.id2label
-     df['class'] = [id2label[pred] for pred in predictions]
-
-     return df

- # Function for single article classification
- def classify_single_article(text, model, tokenizer):
-     # Preprocess the text
      cleaned_text = preprocess_text(text)
-
-     # Prepare for classification
-     inputs = tokenizer(cleaned_text, padding=True, truncation=True, max_length=512, return_tensors="pt")
-
-     with torch.no_grad():
-         outputs = model(**inputs)
-         logits = outputs.logits
-         prediction = torch.argmax(logits, dim=1).item()
-
-     # Map numeric prediction back to class label
-     id2label = model.config.id2label
-     category = id2label[prediction]
-     confidence = torch.nn.functional.softmax(logits, dim=1).max().item() * 100
-
-     return category, round(confidence, 2)

- # Main app
- def main():
-     st.title("News Classifier 📢")
-
-     # Sidebar for navigation
-     st.sidebar.title("Navigation")
-     app_mode = st.sidebar.radio("Choose the app mode", ["News Classification", "Question Answering"])
-
-     # Section for Single Article Classification
-     if app_mode == "News Classification":
-         st.header("📰 Single Article Classification")
-         st.write("Enter a news article or upload a CSV file to classify the content.")
-
-         # Text input for single article classification
-         text_input = st.text_area("Enter News Text", placeholder="Type or paste news content here...")
-         if st.button("🔍 Classify"):
-             if text_input:
-                 # Load classification model
-                 with st.spinner("Loading classification model..."):
-                     model, tokenizer = load_classification_model()

-                 # Classify the text
-                 with st.spinner("Classifying the article..."):
-                     category, confidence = classify_single_article(text_input, model, tokenizer)
-                     st.write(f"*Predicted Category:* {category}")
-                     st.write(f"*Confidence Level:* {confidence}%")
-             else:
-                 st.warning("Please enter some text to classify.")
-
-         # File upload for bulk classification
-         st.subheader("📂 Bulk Classification (CSV)")
-         file_input = st.file_uploader("Upload CSV File", type="csv")
-         if file_input:
-             df = pd.read_csv(file_input)
-
-             # Display sample of the data
-             st.subheader("Sample of uploaded data")
-             st.dataframe(df.head())
-
-             # Check if the required column exists
-             if 'content' not in df.columns:
-                 st.error("The CSV file must contain a 'content' column with the news articles text.")
-             else:
-                 # Load model and tokenizer
-                 with st.spinner("Loading classification model..."):
-                     model, tokenizer = load_classification_model()
-
-                 # Classify button
-                 if st.button("Classify Articles"):
-                     with st.spinner("Classifying news articles..."):
-                         # Perform classification
-                         result_df = classify_news(df, model, tokenizer)
-
-                         # Display results
-                         st.subheader("Classification Results")
-                         st.dataframe(result_df[['content', 'class']])
-
-                         # Save to CSV
-                         csv = result_df.to_csv(index=False)
-                         st.download_button(
-                             label="Download output.csv",
-                             data=csv,
-                             file_name="output.csv",
-                             mime="text/csv"
-                         )
-
-                         # Show distribution of classes
-                         st.subheader("Class Distribution")
-                         class_counts = result_df['class'].value_counts()
-                         st.bar_chart(class_counts)
-
-     # Section for Question Answering
-     elif app_mode == "Question Answering":
-         st.header("💬 AI Chat Assistant")
-         st.write("Ask questions about news content and get answers using a Q&A model.")
-
-         # Text area for news content
-         news_content = st.text_area("Paste news article content here:", height=200)
-
-         # Question input
-         question = st.text_input("Enter your question about the article:")
-
-         if news_content and question:
-             # Load QA pipeline
-             with st.spinner("Loading Q&A model..."):
-                 qa_pipeline = load_qa_pipeline()
-
-             # Get answer
-             if st.button("Get Answer"):
-                 with st.spinner("Finding answer..."):
-                     result = qa_pipeline(question=question, context=news_content)
-
-                     # Display results
-                     st.subheader("Answer")
-                     st.write(result["answer"])
-
-                     st.subheader("Confidence")
-                     st.progress(float(result["score"]))
-                     st.write(f"Confidence Score: {result['score']:.4f}")

- if __name__ == "__main__":
-     main()
  import streamlit as st
  import pandas as pd
+ import string
  import re
+ import nltk
  from nltk.corpus import stopwords
  from nltk.tokenize import word_tokenize
  from nltk.stem import WordNetLemmatizer
+ from nltk.corpus import wordnet
+ from transformers import pipeline
+ from PIL import Image
+
+ # Download necessary NLTK data
+ nltk.download("stopwords")
+ nltk.download("punkt")
+ nltk.download("wordnet")
+ nltk.download("averaged_perceptron_tagger")

+ # Load Models
+ news_classifier = pipeline("text-classification", model="Oneli/News_Classification")
+ qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")

+ # Label Mapping
+ label_mapping = {
+     "LABEL_0": "Business",
+     "LABEL_1": "Opinion",
+     "LABEL_2": "Political Gossip",
+     "LABEL_3": "Sports",
+     "LABEL_4": "World News"
+ }

+ # Store classified article for QA
+ context_storage = {"context": "", "bulk_context": "", "num_articles": 0}

+ # Preprocessing functions
+ def remove_punctuation(text):
+     return text.translate(str.maketrans('', '', string.punctuation))

+ def remove_special_characters(text):
+     return re.sub(r'[^A-Za-z\s]', '', text)

+ def remove_stopwords(text):
+     stop_words = set(stopwords.words('english'))
+     return " ".join([word for word in text.split() if word not in stop_words])
+
+ def tokenize_text(text):
+     return word_tokenize(text)
+
+ def lemmatize_tokens(tokens):
+     lemmatizer = WordNetLemmatizer()
+     wordnet_map = {"N": wordnet.NOUN, 'V': wordnet.VERB, 'J': wordnet.ADJ, 'R': wordnet.ADV}
+     return [lemmatizer.lemmatize(token, wordnet_map.get(nltk.pos_tag([token])[0][1][0].upper(), wordnet.NOUN)) for token in tokens]
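
Note: lemmatize_tokens above calls nltk.pos_tag once per token, which is slow and tags each word in isolation rather than in sentence context; pos_tag also accepts the whole token list in a single call. A minimal sketch of that variant, not part of this commit (the name lemmatize_tokens_batched is illustrative):

    from nltk import pos_tag
    from nltk.corpus import wordnet
    from nltk.stem import WordNetLemmatizer

    def lemmatize_tokens_batched(tokens):
        # Tag the full sequence once instead of one pos_tag call per token.
        lemmatizer = WordNetLemmatizer()
        wordnet_map = {"N": wordnet.NOUN, "V": wordnet.VERB, "J": wordnet.ADJ, "R": wordnet.ADV}
        return [lemmatizer.lemmatize(token, wordnet_map.get(tag[0].upper(), wordnet.NOUN))
                for token, tag in pos_tag(tokens)]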

  def preprocess_text(text):
      text = text.lower()
+     text = remove_punctuation(text)
+     text = remove_special_characters(text)
+     text = remove_stopwords(text)
+     tokens = tokenize_text(text)
+     tokens = lemmatize_tokens(tokens)
+     return " ".join(tokens)

+ # Classification functions
+ def classify_text(text):
      cleaned_text = preprocess_text(text)
+     result = news_classifier(cleaned_text)[0]
+     category = label_mapping.get(result['label'], "Unknown")
+     confidence = round(result['score'] * 100, 2)
+     context_storage["context"] = cleaned_text
+     return category, f"Confidence: {confidence}%"
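
Note: label_mapping above assumes the checkpoint reports its classes as LABEL_0 through LABEL_4 in exactly this order. The pipeline copies result['label'] straight from the model config, so the assumption can be checked directly; a small probe (the printed shape is hypothetical):

    # Inspect the label names shipped with the Oneli/News_Classification config.
    print(news_classifier.model.config.id2label)
    # Generic names such as {0: 'LABEL_0', ...} mean the config carries no real
    # category names, so a manual map like label_mapping is required; descriptive
    # names would make the map redundant, since result['label'] already holds them.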

+ def classify_csv(file):
+     try:
+         df = pd.read_csv(file, encoding="utf-8")
+         text_column = df.columns[0]
+         df["Cleaned_Text"] = df[text_column].astype(str).apply(preprocess_text)
+         df["Encoded Prediction"] = df["Cleaned_Text"].apply(lambda x: news_classifier(x)[0]['label'])
+         df["Decoded Prediction"] = df["Encoded Prediction"].map(label_mapping)
+         df["Confidence"] = df["Cleaned_Text"].apply(lambda x: round(news_classifier(x)[0]['score'] * 100, 2))
+         context_storage["bulk_context"] = " ".join(df["Cleaned_Text"].dropna().tolist())
+         context_storage["num_articles"] = len(df)
+         output_file = "output.csv"
+         df.to_csv(output_file, index=False)
+         return df, output_file
+     except Exception as e:
+         return None, f"Error: {str(e)}"
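
Note: classify_csv runs the classifier twice per row, once for the label and once for the score. A transformers pipeline also accepts a list of texts and can batch internally, so both values can come from a single pass over the data; a sketch reusing the news_classifier defined above (the helper name and batch_size value are illustrative, not from this commit):

    def classify_texts_once(texts, batch_size=16):
        # One batched pass; each result dict carries both label and score.
        results = news_classifier(texts, batch_size=batch_size, truncation=True)
        labels = [r["label"] for r in results]
        confidences = [round(r["score"] * 100, 2) for r in results]
        return labels, confidences

Inside classify_csv, this would replace the two per-row .apply calls with one call on df["Cleaned_Text"].tolist().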

+ # Streamlit App
+ st.set_page_config(page_title="News Classifier", page_icon="📰")
+ st.image("cover.png", caption="News Classifier 📢", use_column_width=True)

+ st.subheader("📰 Single Article Classification")
+ text_input = st.text_area("Enter News Text", placeholder="Type or paste news content here...")
+ if st.button("🔍 Classify"):
+     if text_input:
+         category, confidence = classify_text(text_input)
+         st.write(f"*Predicted Category:* {category}")
+         st.write(f"*Confidence Level:* {confidence}")
+     else:
+         st.warning("Please enter some text to classify.")

+ st.subheader("📂 Bulk Classification (CSV)")
+ file_input = st.file_uploader("Upload CSV File", type="csv")
+ if file_input:
+     df, output_file = classify_csv(file_input)
+     if df is not None:
+         st.dataframe(df)
+         st.download_button(
+             label="Download Processed CSV",
+             data=open(output_file, 'rb').read(),
+             file_name=output_file,
+             mime="text/csv"
+         )
+     else:
+         st.error(f"Error processing file: {output_file}")
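
Note: the download button writes output.csv to disk and reopens it with open(output_file, 'rb'). st.download_button also accepts an in-memory string, the approach the removed version already used, which avoids the file round-trip; a minimal sketch:

    csv_data = df.to_csv(index=False)
    st.download_button(
        label="Download Processed CSV",
        data=csv_data,
        file_name="output.csv",
        mime="text/csv"
    )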