Ashendilantha committed
Commit e8f4b00 · verified · 1 Parent(s): 8dbdc8f

Update app.py

Files changed (1)
  1. app.py +108 -127
app.py CHANGED
@@ -1,156 +1,137 @@
import streamlit as st
import pandas as pd
import re
import nltk
- from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
- from transformers import pipeline
- from PIL import Image

- # Ensure NLTK resources are downloaded correctly
- nltk.download('stopwords')
- nltk.download('punkt')

- # Load Models
- news_classifier = pipeline("text-classification", model="Oneli/News_Classification")

- # Preprocessing Function
- lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

def preprocess_text(text):
    if pd.isna(text):
        return ""

-     # Convert to lowercase
    text = text.lower()
-
-     # Remove URLs
    text = re.sub(r'http\S+|www\S+|https\S+', '', text)
-
-     # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)
-
-     # Remove special characters and numbers
    text = re.sub(r'[^a-zA-Z\s]', '', text)
-
-     # Tokenize
    tokens = word_tokenize(text)
-
-     # Remove stopwords and lemmatize
    cleaned_tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
-
-     # Join tokens back into text
    cleaned_text = ' '.join(cleaned_tokens)
-
    return cleaned_text

- # Load Cover Image
- cover_image = Image.open("cover.png")  # Ensure this image exists
-
- # Label Mapping
- label_mapping = {
-     "LABEL_0": "Business",
-     "LABEL_1": "Opinion",
-     "LABEL_2": "Political Gossip",
-     "LABEL_3": "Sports",
-     "LABEL_4": "World News"
- }
-
- # Store classified article for QA
- context_storage = {"context": "", "bulk_context": "", "num_articles": 0}
-
- # Function for Single Article Classification
- def classify_text(text):
-     text = preprocess_text(text)  # Preprocess text
-     result = news_classifier(text)[0]
-     category = label_mapping.get(result['label'], "Unknown")
-     confidence = round(result['score'] * 100, 2)

-     # Store context for QA
-     context_storage["context"] = text
-
-     return category, f"Confidence: {confidence}%"
-
- # Function for Bulk Classification
- def classify_csv(file_path):
-     try:
-         df = pd.read_csv(file_path, encoding="utf-8")
-
-         # Automatically detect the column containing text
-         text_column = df.columns[0]  # Assume first column is the text column
-
-         df["Encoded Prediction"] = df[text_column].apply(lambda x: news_classifier(preprocess_text(str(x)))[0]['label'])
-         df["Decoded Prediction"] = df["Encoded Prediction"].map(label_mapping)
-         df["Confidence"] = df[text_column].apply(lambda x: round(news_classifier(preprocess_text(str(x)))[0]['score'] * 100, 2))
-
-         # Store all text as a single context for QA
-         context_storage["bulk_context"] = " ".join(df[text_column].dropna().astype(str).tolist())
-         context_storage["num_articles"] = len(df)
-
-         output_file = "output.csv"
-         df.to_csv(output_file, index=False)
-
-         return df, output_file
-     except Exception as e:
-         return None, f"Error: {str(e)}"
-
- # Function to Load Q&A Pipeline
- def load_qa_pipeline():
-     return pipeline("question-answering", model="deepset/roberta-base-squad2")
-
- # Streamlit App Layout
- st.set_page_config(page_title="News Classifier", page_icon="📰")
-
- # Load and display the cover image
- st.image(cover_image, caption="News Classifier 📢", use_container_width=True)
-
- # Section for Single Article Classification
- st.subheader("📰 Single Article Classification")
- text_input = st.text_area("Enter News Text", placeholder="Type or paste news content here...")
- if st.button("🔍 Classify"):
-     if text_input:
-         category, confidence = classify_text(text_input)
-         st.write(f"**Predicted Category:** {category}")
-         st.write(f"**Confidence Level:** {confidence}")
-     else:
-         st.warning("Please enter some text to classify.")
-
- # Section for Bulk CSV Classification
- st.subheader("📂 Bulk Classification (CSV)")
- file_input = st.file_uploader("Upload CSV File", type="csv")
- if file_input:
-     df, output_file = classify_csv(file_input)
-     if df is not None:
-         st.dataframe(df)
-         st.download_button(
-             label="Download Processed CSV",
-             data=open(output_file, 'rb').read(),
-             file_name=output_file,
-             mime="text/csv"
-         )
-     else:
-         st.error(f"Error processing file: {output_file}")
-
- # Section for Q&A
- st.subheader("💬 Q&A Model")
- question = st.text_input("Ask a question about the news article:", placeholder="Ask anything related to the news...")
- if question:
-     # Load the QA model and get the answer
-     with st.spinner("Loading Q&A model..."):
-         qa_pipeline = load_qa_pipeline()

-     if st.button("Get Answer"):
-         with st.spinner("Finding answer..."):
-             result = qa_pipeline(question=question, context=context_storage["context"])
-
-             # Display results
-             st.subheader("Answer")
-             st.write(result["answer"])

-             st.subheader("Confidence")
-             st.progress(float(result["score"]))
-             st.write(f"Confidence Score: {result['score']:.4f}")
 
import streamlit as st
import pandas as pd
+ import numpy as np
import re
import nltk
from nltk.corpus import stopwords
+ from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

+ # Set page configuration
+ st.set_page_config(page_title="News Analysis App", layout="wide")

+ # Download required NLTK resources
+ @st.cache_resource
+ def download_nltk_resources():
+     nltk.download('punkt')
+     nltk.download('stopwords')
+     nltk.download('wordnet')

+ download_nltk_resources()
+
+ # Initialize preprocessor components
stop_words = set(stopwords.words('english'))
+ lemmatizer = WordNetLemmatizer()

+ # Load classification model
+ @st.cache_resource
+ def load_classification_model():
+     model_name = "Oneli/News_Classification"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForSequenceClassification.from_pretrained(model_name)
+     return model, tokenizer
+
+ # Load Q&A pipeline
+ @st.cache_resource
+ def load_qa_pipeline():
+     qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
+     return qa_pipeline
+
+ # Preprocessing function
def preprocess_text(text):
    if pd.isna(text):
        return ""
    text = text.lower()
    text = re.sub(r'http\S+|www\S+|https\S+', '', text)
    text = re.sub(r'<.*?>', '', text)
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    tokens = word_tokenize(text)
    cleaned_tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
    cleaned_text = ' '.join(cleaned_tokens)
    return cleaned_text

+ # Batch classification function
+ def classify_news(df, model, tokenizer):
+     df['cleaned_content'] = df['content'].apply(preprocess_text)
+     texts = df['cleaned_content'].tolist()
+     predictions = []
+     batch_size = 16

+     for i in range(0, len(texts), batch_size):
+         batch_texts = texts[i:i+batch_size]
+         inputs = tokenizer(batch_texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+             logits = outputs.logits
+             batch_predictions = torch.argmax(logits, dim=1).tolist()
+         predictions.extend(batch_predictions)

+     id2label = model.config.id2label
+     df['class'] = [id2label[pred] for pred in predictions]
+     return df
+
+ # Main app
+ def main():
+     st.title("News Analysis Application")
+     st.sidebar.title("Navigation")
+     app_mode = st.sidebar.radio("Choose the app mode", ["News Classification", "Question Answering"])
+
+     if app_mode == "News Classification":
+         st.header("News Article Classification")
+         uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
+
+         if uploaded_file is not None:
+             df = pd.read_csv(uploaded_file)
+             st.subheader("Sample of uploaded data")
+             st.dataframe(df.head())

+             if 'content' not in df.columns:
+                 st.error("The CSV file must contain a 'content' column.")
+             else:
+                 with st.spinner("Loading model..."):
+                     model, tokenizer = load_classification_model()
+
+                 if st.button("Classify Articles"):
+                     with st.spinner("Classifying news articles..."):
+                         result_df = classify_news(df, model, tokenizer)
+                     st.subheader("Classification Results")
+                     st.dataframe(result_df[['content', 'class']])
+                     csv = result_df.to_csv(index=False)
+                     st.download_button("Download output.csv", csv, "output.csv", "text/csv")
+                     st.subheader("Class Distribution")
+                     st.bar_chart(result_df['class'].value_counts())
+
+     elif app_mode == "Question Answering":
+         st.header("News Article Q&A")
+         uploaded_file = st.file_uploader("Upload CSV for Q&A", type="csv")
+
+         if uploaded_file is not None:
+             df = pd.read_csv(uploaded_file)
+             if 'content' not in df.columns:
+                 st.error("The CSV file must contain a 'content' column.")
+             else:
+                 # Join the raw article text as the QA context ('cleaned_content' only exists after classification)
+                 combined_text = " ".join(df['content'].dropna().astype(str).tolist())
+                 question = st.text_input("Enter your question about the news:")
+
+                 if combined_text and question:
+                     with st.spinner("Loading Q&A model..."):
+                         qa_pipeline = load_qa_pipeline()
+
+                     if st.button("Get Answer"):
+                         with st.spinner("Finding answer..."):
+                             result = qa_pipeline(question=question, context=combined_text)
+                         st.subheader("Answer")
+                         st.write(result["answer"])
+                         st.subheader("Confidence")
+                         st.progress(float(result["score"]))
+                         st.write(f"Confidence Score: {result['score']:.4f}")
+
+ if __name__ == "__main__":
+     main()
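
For a quick local check of the updated app, the sketch below builds a minimal input CSV that satisfies the 'content'-column validation performed by both app modes in the new code. The file name sample_news.csv and the example rows are illustrative, not part of the commit.

    import pandas as pd

    # Tiny input with the 'content' column both app modes validate
    sample = pd.DataFrame({
        "content": [
            "The central bank raised interest rates by a quarter point.",
            "The home side clinched the title with a stoppage-time goal.",
        ]
    })
    sample.to_csv("sample_news.csv", index=False)  # illustrative file name

Launching the app with streamlit run app.py and uploading sample_news.csv exercises both the classification and Q&A paths.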