TAgroup5 commited on
Commit
6ad09a0
·
verified ·
1 Parent(s): f3c8efb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -7,27 +7,24 @@ import nltk
7
  from nltk.tokenize import word_tokenize
8
  from nltk.corpus import stopwords
9
  from nltk.stem import WordNetLemmatizer
10
- from transformers import AutoTokenizer
11
  import nltk
 
 
12
  nltk.download('punkt', download_dir='/root/nltk_data')
13
  nltk.download('stopwords', download_dir='/root/nltk_data')
14
  nltk.download('wordnet', download_dir='/root/nltk_data')
15
 
16
-
17
-
18
  # Initialize lemmatizer and stopwords
19
  lemmatizer = WordNetLemmatizer()
20
  stop_words = set(stopwords.words('english'))
21
 
22
- # Load fine-tuned model and tokenizer
23
- model_name = "TAgroup5/news-classification-model"
24
- model = AutoModelForSequenceClassification.from_pretrained("TAgroup5/news-classification-model")
25
- tokenizer = AutoTokenizer.from_pretrained("TAgroup5/news-classification-model")
26
 
27
  # Initialize pipelines
28
  text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
29
- qa_model_name = "distilbert-base-uncased-distilled-squad" # Example of a common Q&A model
30
- qa_pipeline = pipeline("question-answering", model=qa_model_name, tokenizer=qa_model_name)
31
 
32
  # Streamlit App
33
  st.title("News Classification and Q&A")
@@ -50,7 +47,7 @@ if uploaded_file is not None:
50
  st.write("Preview of uploaded data:")
51
  st.dataframe(df.head())
52
 
53
- # Preprocessing function
54
  def preprocess_text(text):
55
  text = text.lower() # Convert to lowercase
56
  text = re.sub(r'[^a-z\s]', '', text) # Remove special characters & numbers
@@ -61,6 +58,8 @@ if uploaded_file is not None:
61
 
62
  # Apply preprocessing and classification
63
  df['processed_content'] = df['content'].apply(preprocess_text)
 
 
64
  df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
65
 
66
  # Show results
@@ -80,6 +79,8 @@ question = st.text_input("Ask a question:")
80
  context = st.text_area("Provide the news article or content for the Q&A:", height=150)
81
 
82
  if question and context.strip():
 
 
83
  result = qa_pipeline(question=question, context=context)
84
 
85
  # Check if the result contains an answer
 
7
  from nltk.tokenize import word_tokenize
8
  from nltk.corpus import stopwords
9
  from nltk.stem import WordNetLemmatizer
 
10
  import nltk
11
+
12
+ # Download NLTK resources
13
  nltk.download('punkt', download_dir='/root/nltk_data')
14
  nltk.download('stopwords', download_dir='/root/nltk_data')
15
  nltk.download('wordnet', download_dir='/root/nltk_data')
16
 
 
 
17
  # Initialize lemmatizer and stopwords
18
  lemmatizer = WordNetLemmatizer()
19
  stop_words = set(stopwords.words('english'))
20
 
21
+ # Load fine-tuned model and tokenizer (adjust the model name)
22
+ model_name = "TAgroup5/news-classification-model" # Replace with the correct model name
23
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
25
 
26
  # Initialize pipelines
27
  text_classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
 
 
28
 
29
  # Streamlit App
30
  st.title("News Classification and Q&A")
 
47
  st.write("Preview of uploaded data:")
48
  st.dataframe(df.head())
49
 
50
+ # Preprocessing function to clean the text
51
  def preprocess_text(text):
52
  text = text.lower() # Convert to lowercase
53
  text = re.sub(r'[^a-z\s]', '', text) # Remove special characters & numbers
 
58
 
59
  # Apply preprocessing and classification
60
  df['processed_content'] = df['content'].apply(preprocess_text)
61
+
62
+ # Classify each record into one of the five classes
63
  df['class'] = df['processed_content'].apply(lambda x: text_classification_pipeline(x)[0]['label'] if x.strip() else "Unknown")
64
 
65
  # Show results
 
79
  context = st.text_area("Provide the news article or content for the Q&A:", height=150)
80
 
81
  if question and context.strip():
82
+ qa_model_name = "distilbert-base-uncased-distilled-squad" # Example of a common Q&A model
83
+ qa_pipeline = pipeline("question-answering", model=qa_model_name, tokenizer=qa_model_name)
84
  result = qa_pipeline(question=question, context=context)
85
 
86
  # Check if the result contains an answer