Sephfox committed
Commit 0241af2 · verified · 1 Parent(s): ff1fffe

Update app.py

Files changed (1):
  app.py  +31 -41
app.py CHANGED
@@ -41,8 +41,8 @@ emotions_target = pd.Categorical(df['emotion']).codes
 emotion_classes = pd.Categorical(df['emotion']).categories
 
 # Load pre-trained BERT model for emotion prediction
-emotion_prediction_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
-emotion_prediction_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+emotion_prediction_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
+emotion_prediction_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
 
 # Lazy loading for the fine-tuned language model (DialoGPT)
 _finetuned_lm_tokenizer = None
@@ -154,19 +154,19 @@ def predict_emotion(context):
     emotion_prediction_pipeline = pipeline('text-classification', model=emotion_prediction_model, tokenizer=emotion_prediction_tokenizer, top_k=None)
     predictions = emotion_prediction_pipeline(context)
     emotion_scores = {prediction['label']: prediction['score'] for prediction in predictions[0]}
-    predicted_label = max(emotion_scores, key=emotion_scores.get)
+    predicted_emotion = max(emotion_scores, key=emotion_scores.get)
 
-    # Map the predicted label to our emotion categories
+    # Map the predicted emotion to our emotion categories
     emotion_mapping = {
-        'LABEL_0': 'joy',
-        'LABEL_1': 'sadness',
-        'LABEL_2': 'anger',
-        'LABEL_3': 'fear',
-        'LABEL_4': 'surprise',
-        'LABEL_5': 'disgust'
+        'sadness': 'sadness',
+        'joy': 'joy',
+        'love': 'pleasure',
+        'anger': 'anger',
+        'fear': 'fear',
+        'surprise': 'surprise'
     }
 
-    return emotion_mapping.get(predicted_label, 'neutral')
+    return emotion_mapping.get(predicted_emotion, 'neutral')
 
 def generate_text(prompt, emotion=None, max_length=100):
     finetuned_lm_tokenizer, finetuned_lm_model = get_finetuned_lm_model()
@@ -231,6 +231,7 @@ def reset_emotions():
         emotions[emotion]['percentage'] = 10
         emotions[emotion]['intensity'] = 0
     emotions['ideal_state']['percentage'] = 100
+    return get_emotion_summary()
 
 def respond_to_user(user_input, chat_history):
     # Predict the emotion from the user input
@@ -261,36 +262,25 @@ def respond_to_user(user_input, chat_history):
 
     return response, chat_history, get_emotion_summary()
 
-def chat_interface(user_input, history):
-    response, updated_history, emotion_summary = respond_to_user(user_input, history)
-    return response, updated_history, emotion_summary
-
 # Gradio interface
-iface = gr.Interface(
-    fn=chat_interface,
-    inputs=[
-        gr.Textbox(lines=2, placeholder="Type your message here..."),
-        gr.State([])
-    ],
-    outputs=[
-        gr.Textbox(label="AI Response"),
-        gr.State(),
-        gr.Textbox(label="Current Emotional State", lines=10)
-    ],
-    title="Emotion-Aware AI Chatbot",
-    description="Chat with an AI that understands and responds to emotions.",
-    allow_flagging="never",
-    theme="default"
-)
-
-# Add a button to reset emotions
-reset_button = gr.Button("Reset Emotions")
-reset_button.click(
-    fn=reset_emotions,
-    inputs=None,
-    outputs=gr.Textbox(label="Current Emotional State", lines=10),
-    api_name="reset_emotions"
-)
+with gr.Blocks() as demo:
+    gr.Markdown("# Emotion-Aware AI Chatbot")
+    gr.Markdown("Chat with an AI that understands and responds to emotions.")
+
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox(label="Type your message here...")
+    clear = gr.Button("Clear")
+
+    emotion_state = gr.Textbox(label="Current Emotional State", lines=10)
+    reset_button = gr.Button("Reset Emotions")
+
+    def user(user_message, history):
+        response, updated_history, emotion_summary = respond_to_user(user_message, history)
+        return "", updated_history, emotion_summary
+
+    msg.submit(user, [msg, chatbot], [msg, chatbot, emotion_state])
+    clear.click(lambda: None, None, chatbot, queue=False)
+    reset_button.click(reset_emotions, None, emotion_state, queue=False)
 
 if __name__ == "__main__":
-    iface.launch()
+    demo.launch()
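
Note on the model swap: the bare "distilbert-base-uncased" checkpoint has an untrained classification head that only emits generic LABEL_N ids, while "bhadresh-savani/distilbert-base-uncased-emotion" is fine-tuned to return named labels (sadness, joy, love, anger, fear, surprise), which is why the mapping in predict_emotion now keys on label names. A minimal standalone sketch of the new prediction path; the sample sentence and the printout are illustrative only, not part of app.py:

from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Same checkpoint the commit switches to; weights are downloaded on first use.
model_name = "bhadresh-savani/distilbert-base-uncased-emotion"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# top_k=None makes the pipeline return a score for every label, not just the best one.
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)

predictions = classifier("I can't believe you did that!")
# predictions[0] is a list of dicts such as {'label': 'anger', 'score': 0.95}
emotion_scores = {p["label"]: p["score"] for p in predictions[0]}
predicted_emotion = max(emotion_scores, key=emotion_scores.get)

# Same mapping the commit adds: 'love' folds into 'pleasure',
# and anything unrecognised falls back to 'neutral'.
emotion_mapping = {
    "sadness": "sadness",
    "joy": "joy",
    "love": "pleasure",
    "anger": "anger",
    "fear": "fear",
    "surprise": "surprise",
}
print(emotion_mapping.get(predicted_emotion, "neutral"))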
 
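
On the UI side, the single gr.Interface plus separate reset button is replaced by one gr.Blocks layout, so the chatbot, emotion readout, and reset button share a page and the textbox is cleared after each submit (the callback returns "" as its first output). A stripped-down sketch of the same wiring with a stubbed respond_to_user so it runs without the rest of app.py; the stub and its canned reply are assumptions, not code from the commit:

import gradio as gr

def respond_to_user(user_input, chat_history):
    # Stand-in for the real function in app.py; appends a canned reply.
    # Tuple-style history ([user, bot] pairs) matches what the app's callbacks appear to assume.
    chat_history = (chat_history or []) + [(user_input, "canned reply")]
    return "unused response text", chat_history, "current emotional state summary"

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Type your message here...")
    clear = gr.Button("Clear")
    emotion_state = gr.Textbox(label="Current Emotional State", lines=10)

    def user(user_message, history):
        # First return value "" clears the textbox; the other two update
        # the Chatbot history and the emotion readout, mapped positionally.
        _, updated_history, emotion_summary = respond_to_user(user_message, history)
        return "", updated_history, emotion_summary

    # Enter in the textbox triggers submit; Clear wipes the chat history.
    msg.submit(user, [msg, chatbot], [msg, chatbot, emotion_state])
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()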