TAgroup5 committed on
Commit
9ff5a0e
·
verified ·
1 Parent(s): 2ec8bb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -85
app.py CHANGED
@@ -5,8 +5,9 @@ import io
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
7
 
8
- # Load fine-tuned models and tokenizers for both functions
9
- model_name_classification = "TAgroup5/news-classification-model"
 
10
  model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
11
  tokenizer = AutoTokenizer.from_pretrained(model_name_classification)
12
 
@@ -15,90 +16,69 @@ model_qa = AutoModelForQuestionAnswering.from_pretrained(model_name_qa)
15
  tokenizer_qa = AutoTokenizer.from_pretrained(model_name_qa)
16
 
17
  # Initialize pipelines
18
- from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
 
19
 
20
- model_name = "distilbert-base-cased-distilled-squad" # Example model
21
- model = AutoModelForQuestionAnswering.from_pretrained(model_name)
22
- tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
24
- qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
 
25
 
26
- text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
27
- qa_pipeline = pipeline("question-answering", model=model)
28
-
29
- # Streamlit App
30
- st.set_page_config(page_title="News Classification & Q&A", page_icon="📰", layout="wide")
31
- st.markdown(
32
- """
33
- <style>
34
- body {background-color: #f4f4f4;}
35
- .title {text-align: center; font-size: 36px; font-weight: bold; color: #ff4b4b;}
36
- .subheader {font-size: 24px; color: #333; margin-bottom: 20px; text-align: right;}
37
- .stTextInput>div>div>input {border-radius: 10px;}
38
- .stTextArea>div>div>textarea {border-radius: 10px;}
39
- .stButton>button {border-radius: 10px; background-color: #ff4b4b; color: white; font-weight: bold;}
40
- </style>
41
- """,
42
- unsafe_allow_html=True,
43
- )
44
-
45
- st.markdown('<h1 class="title">📰 News Classification & Q&A App</h1>', unsafe_allow_html=True)
46
-
47
- col1, col2 = st.columns([2, 1])
48
- with col2:
49
- # ====================== Component 1: News Classification ====================== #
50
- st.markdown('<h2 class="subheader">📌 Classify News Articles</h2>', unsafe_allow_html=True)
51
- st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
52
-
53
- uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
54
-
55
- if uploaded_file is not None:
56
- try:
57
- df = pd.read_csv(uploaded_file, encoding="utf-8") # Handle encoding issues
58
- except UnicodeDecodeError:
59
- df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
60
-
61
- if 'content' not in df.columns:
62
- st.error("❌ Error: The uploaded CSV must contain a 'content' column.")
63
- else:
64
- st.success("✅ File successfully uploaded!")
65
- st.write("Preview of uploaded data:")
66
- st.dataframe(df.head())
67
-
68
- # Preprocessing function to clean the text
69
- def preprocess_text(text):
70
- text = text.lower() # Convert to lowercase
71
- text = re.sub(r'\s+', ' ', text) # Remove extra spaces
72
- text = re.sub(r'[^a-z\s]', '', text) # Remove special characters & numbers
73
- return text
74
-
75
- # Apply preprocessing and classification
76
- df['processed_content'] = df['content'].apply(preprocess_text)
77
- df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
78
-
79
- # Show results
80
- st.markdown("### 🔹 Classification Results:")
81
- st.dataframe(df[['content', 'class']])
82
-
83
- # Provide CSV download
84
- output = io.BytesIO()
85
- df.to_csv(output, index=False, encoding="utf-8-sig")
86
- st.download_button(label="⬇️ Download classified news", data=output.getvalue(), file_name="classified_news.csv", mime="text/csv")
87
-
88
- # ====================== Component 2: Q&A ====================== #
89
- st.markdown('<h2 class="subheader">❓ Ask a Question About the News</h2>', unsafe_allow_html=True)
90
- st.markdown("Enter a question and provide a news article to get an answer.")
91
-
92
- question = st.text_input("🔍 Ask a question:")
93
- context = st.text_area("📝 Provide the news article or content:", height=150)
94
-
95
- if question and context.strip():
96
- model_name_qa = "distilbert-base-uncased-distilled-squad"
97
- qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
98
- result = qa_pipeline(question=question, context=context)
99
 
100
- # Display Answer
101
- if 'answer' in result and result['answer']:
102
- st.markdown(f"### ✅ Answer: {result['answer']}")
103
- else:
104
- st.markdown("### ❌ No answer found in the provided content.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from transformers import AutoModelForQuestionAnswering
7
 
8
+
9
+ # Load fine-tuned models and tokenizers for both functions
10
+ model_name_classification = "TAgroup5/news-classification-model"
11
  model = AutoModelForSequenceClassification.from_pretrained(model_name_classification)
12
  tokenizer = AutoTokenizer.from_pretrained(model_name_classification)
13
 
 
16
  tokenizer_qa = AutoTokenizer.from_pretrained(model_name_qa)
17
 
18
  # Initialize pipelines
19
+ text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
20
+ qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
21
 
 
 
 
22
 
23
+ # Streamlit App
24
+ st.title("News Classification and Q&A")
25
 
26
+ ## ====================== Component 1: News Classification ====================== ##
27
+ st.header("Classify News Articles")
28
+ st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")
29
+
30
+ uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
31
+
32
+ if uploaded_file is not None:
33
+ try:
34
+ df = pd.read_csv(uploaded_file, encoding="utf-8") # Handle encoding issues
35
+ except UnicodeDecodeError:
36
+ df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")
37
+
38
+ if 'content' not in df.columns:
39
+ st.error("Error: The uploaded CSV must contain a 'content' column.")
40
+ else:
41
+ st.write("Preview of uploaded data:")
42
+ st.dataframe(df.head())
43
+
44
+ # Preprocessing function to clean the text
45
+ def preprocess_text(text):
46
+ text = text.lower() # Convert to lowercase
47
+ text = re.sub(r'\s+', ' ', text) # Remove extra spaces
48
+ text = re.sub(r'[^a-z\s]', '', text) # Remove special characters & numbers
49
+ # You don't need tokenization here, as the model tokenizer will handle it
50
+ return text
51
+
52
+
53
+ # Apply preprocessing and classification
54
+ df['processed_content'] = df['content'].apply(preprocess_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # Classify each record into one of the five classes
57
+ df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
58
+
59
+ # Show results
60
+ st.write("Classification Results:")
61
+ st.dataframe(df[['content', 'class']])
62
+
63
+ # Provide CSV download
64
+ output = io.BytesIO()
65
+ df.to_csv(output, index=False, encoding="utf-8-sig")
66
+ st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")
67
+
68
+ ## ====================== Component 2: Q&A ====================== ##
69
+ st.header("Ask a Question About the News")
70
+ st.markdown("Enter a question and provide a news article to get an answer.")
71
+
72
+ question = st.text_input("Ask a question:")
73
+ context = st.text_area("Provide the news article or content for the Q&A:", height=150)
74
+
75
+ if question and context.strip():
76
+ model_name_qa = "distilbert-base-uncased-distilled-squad" # Example of a common Q&A model
77
+ qa_pipeline = pipeline("question-answering", model=model_name_qa, tokenizer=model_name_qa)
78
+ result = qa_pipeline(question=question, context=context)
79
+
80
+ # Check if the result contains an answer
81
+ if 'answer' in result and result['answer']:
82
+ st.write("Answer:", result['answer'])
83
+ else:
84
+ st.write("No answer found in the provided content.")