TAgroup5 commited on
Commit
f72fd34
Β·
verified Β·
1 Parent(s): 9ff5a0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -30
app.py CHANGED
@@ -5,9 +5,49 @@ import io
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
7
 
8
-
9
- # Load fine-tuned models and tokenizers for both functions
10
- model_name_classification = "TAgroup5/news-classification-model"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
12
  tokenizer = AutoTokenizer.from_pretrained(model_name_classification)
13
 
@@ -19,66 +59,59 @@ tokenizer_qa = AutoTokenizer.from_pretrained(model_name_qa)
19
  text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
20
  qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
21
 
22
-
23
  # Streamlit App
24
- st.title("News Classification and Q&A")
25
 
26
  ## ====================== Component 1: News Classification ====================== ##
27
- st.header("Classify News Articles")
28
- st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
29
 
30
- uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
31
 
32
  if uploaded_file is not None:
33
  try:
34
- df = pd.read_csv(uploaded_file, encoding="utf-8") # Handle encoding issues
35
  except UnicodeDecodeError:
36
  df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
37
 
38
  if 'content' not in df.columns:
39
- st.error("Error: The uploaded CSV must contain a 'content' column.")
40
  else:
41
- st.write("Preview of uploaded data:")
42
  st.dataframe(df.head())
43
 
44
- # Preprocessing function to clean the text
45
  def preprocess_text(text):
46
- text = text.lower() # Convert to lowercase
47
- text = re.sub(r'\s+', ' ', text) # Remove extra spaces
48
- text = re.sub(r'[^a-z\s]', '', text) # Remove special characters & numbers
49
- # You don't need tokenization here, as the model tokenizer will handle it
50
  return text
51
 
52
-
53
- # Apply preprocessing and classification
54
  df['processed_content'] = df['content'].apply(preprocess_text)
55
-
56
- # Classify each record into one of the five classes
57
  df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
58
 
59
  # Show results
60
- st.write("Classification Results:")
61
  st.dataframe(df[['content', 'class']])
62
 
63
  # Provide CSV download
64
  output = io.BytesIO()
65
  df.to_csv(output, index=False, encoding="utf-8-sig")
66
- st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")
67
 
68
  ## ====================== Component 2: Q&A ====================== ##
69
- st.header("Ask a Question About the News")
70
- st.markdown("Enter a question and provide a news article to get an answer.")
71
 
72
- question = st.text_input("Ask a question:")
73
- context = st.text_area("Provide the news article or content for the Q&A:", height=150)
74
 
75
  if question and context.strip():
76
- model_name_qa = "distilbert-base-uncased-distilled-squad" # Example of a common Q&A model
77
  qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
78
  result = qa_pipeline(question=question, context=context)
79
 
80
- # Check if the result contains an answer
81
  if 'answer' in result and result['answer']:
82
- st.write("Answer:", result['answer'])
83
  else:
84
- st.write("No answer found in the provided content.")
 
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
7
 
8
+ # Streamlit UI
9
+ st.set_page_config(page_title="News Classifier & Q&A", layout="wide")
10
+ st.markdown("""
11
+ <style>
12
+ body {
13
+ background-color: #f4f4f4;
14
+ color: #333333;
15
+ font-family: 'Arial', sans-serif;
16
+ }
17
+ .stApp {
18
+ background-color: white;
19
+ padding: 20px;
20
+ border-radius: 10px;
21
+ box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.1);
22
+ }
23
+ h1, h2 {
24
+ color: #ff4b4b;
25
+ }
26
+ .stButton>button {
27
+ background-color: #ff4b4b !important;
28
+ color: white;
29
+ font-size: 16px;
30
+ border-radius: 5px;
31
+ }
32
+ .stDownloadButton>button {
33
+ background-color: #28a745 !important;
34
+ color: white;
35
+ font-size: 16px;
36
+ border-radius: 5px;
37
+ }
38
+ .stTextInput>div>div>input {
39
+ border-radius: 5px;
40
+ border: 1px solid #ccc;
41
+ }
42
+ .stTextArea>div>textarea {
43
+ border-radius: 5px;
44
+ border: 1px solid #ccc;
45
+ }
46
+ </style>
47
+ """, unsafe_allow_html=True)
48
+
49
+ # Load fine-tuned models and tokenizers
50
+ model_name_classification = "TAgroup5/news-classification-model"
51
  model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
52
  tokenizer = AutoTokenizer.from_pretrained(model_name_classification)
53
 
 
59
  text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
60
  qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
61
 
 
62
  # Streamlit App
63
+ st.title("πŸ“° News Classification and Q&A πŸ€–")
64
 
65
  ## ====================== Component 1: News Classification ====================== ##
66
+ st.header("πŸ“Œ Classify News Articles")
67
+ st.markdown("Upload a CSV file with a **'content'** column to classify news into categories.")
68
 
69
+ uploaded_file = st.file_uploader("πŸ“‚ Choose a CSV file", type="csv")
70
 
71
  if uploaded_file is not None:
72
  try:
73
+ df = pd.read_csv(uploaded_file, encoding="utf-8")
74
  except UnicodeDecodeError:
75
  df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
76
 
77
  if 'content' not in df.columns:
78
+ st.error("❌ Error: The uploaded CSV must contain a 'content' column.")
79
  else:
80
+ st.write("βœ… Preview of uploaded data:")
81
  st.dataframe(df.head())
82
 
83
+ # Preprocessing function
84
  def preprocess_text(text):
85
+ text = text.lower()
86
+ text = re.sub(r'\s+', ' ', text)
87
+ text = re.sub(r'[^a-z\s]', '', text)
 
88
  return text
89
 
 
 
90
  df['processed_content'] = df['content'].apply(preprocess_text)
 
 
91
  df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
92
 
93
  # Show results
94
+ st.write("πŸ” Classification Results:")
95
  st.dataframe(df[['content', 'class']])
96
 
97
  # Provide CSV download
98
  output = io.BytesIO()
99
  df.to_csv(output, index=False, encoding="utf-8-sig")
100
+ st.download_button(label="πŸ“₯ Download Classified News", data=output.getvalue(), file_name="classified_news.csv", mime="text/csv")
101
 
102
  ## ====================== Component 2: Q&A ====================== ##
103
+ st.header("πŸ’¬ Ask a Question About the News")
104
+ st.markdown("Enter a question and provide a news article to get an AI-generated answer.")
105
 
106
+ question = st.text_input("❓ Ask a question:")
107
+ context = st.text_area("πŸ“° Provide the news article or content:", height=150)
108
 
109
  if question and context.strip():
110
+ model_name_qa = "distilbert-base-uncased-distilled-squad"
111
  qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
112
  result = qa_pipeline(question=question, context=context)
113
 
 
114
  if 'answer' in result and result['answer']:
115
+ st.success(f"βœ… Answer: {result['answer']}")
116
  else:
117
+ st.warning("⚠️ No answer found in the provided content.")