Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -41,8 +41,8 @@ emotions_target = pd.Categorical(df['emotion']).codes
|
|
41 |
emotion_classes = pd.Categorical(df['emotion']).categories
|
42 |
|
43 |
# Load pre-trained BERT model for emotion prediction
|
44 |
-
emotion_prediction_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
|
45 |
-
emotion_prediction_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
|
46 |
|
47 |
# Lazy loading for the fine-tuned language model (DialoGPT)
|
48 |
_finetuned_lm_tokenizer = None
|
@@ -154,19 +154,19 @@ def predict_emotion(context):
|
|
154 |
emotion_prediction_pipeline = pipeline('text-classification', model=emotion_prediction_model, tokenizer=emotion_prediction_tokenizer, top_k=None)
|
155 |
predictions = emotion_prediction_pipeline(context)
|
156 |
emotion_scores = {prediction['label']: prediction['score'] for prediction in predictions[0]}
|
157 |
-
|
158 |
|
159 |
-
# Map the predicted
|
160 |
emotion_mapping = {
|
161 |
-
'
|
162 |
-
'
|
163 |
-
'
|
164 |
-
'
|
165 |
-
'
|
166 |
-
'
|
167 |
}
|
168 |
|
169 |
-
return emotion_mapping.get(
|
170 |
|
171 |
def generate_text(prompt, emotion=None, max_length=100):
|
172 |
finetuned_lm_tokenizer, finetuned_lm_model = get_finetuned_lm_model()
|
@@ -231,6 +231,7 @@ def reset_emotions():
|
|
231 |
emotions[emotion]['percentage'] = 10
|
232 |
emotions[emotion]['intensity'] = 0
|
233 |
emotions['ideal_state']['percentage'] = 100
|
|
|
234 |
|
235 |
def respond_to_user(user_input, chat_history):
|
236 |
# Predict the emotion from the user input
|
@@ -261,36 +262,25 @@ def respond_to_user(user_input, chat_history):
|
|
261 |
|
262 |
return response, chat_history, get_emotion_summary()
|
263 |
|
264 |
-
def chat_interface(user_input, history):
|
265 |
-
response, updated_history, emotion_summary = respond_to_user(user_input, history)
|
266 |
-
return response, updated_history, emotion_summary
|
267 |
-
|
268 |
# Gradio interface
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
)
|
285 |
-
|
286 |
-
|
287 |
-
reset_button = gr.Button("Reset Emotions")
|
288 |
-
reset_button.click(
|
289 |
-
fn=reset_emotions,
|
290 |
-
inputs=None,
|
291 |
-
outputs=gr.Textbox(label="Current Emotional State", lines=10),
|
292 |
-
api_name="reset_emotions"
|
293 |
-
)
|
294 |
|
295 |
if __name__ == "__main__":
|
296 |
-
|
|
|
41 |
emotion_classes = pd.Categorical(df['emotion']).categories
|
42 |
|
43 |
# Load pre-trained BERT model for emotion prediction
|
44 |
+
emotion_prediction_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
|
45 |
+
emotion_prediction_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
|
46 |
|
47 |
# Lazy loading for the fine-tuned language model (DialoGPT)
|
48 |
_finetuned_lm_tokenizer = None
|
|
|
154 |
emotion_prediction_pipeline = pipeline('text-classification', model=emotion_prediction_model, tokenizer=emotion_prediction_tokenizer, top_k=None)
|
155 |
predictions = emotion_prediction_pipeline(context)
|
156 |
emotion_scores = {prediction['label']: prediction['score'] for prediction in predictions[0]}
|
157 |
+
predicted_emotion = max(emotion_scores, key=emotion_scores.get)
|
158 |
|
159 |
+
# Map the predicted emotion to our emotion categories
|
160 |
emotion_mapping = {
|
161 |
+
'sadness': 'sadness',
|
162 |
+
'joy': 'joy',
|
163 |
+
'love': 'pleasure',
|
164 |
+
'anger': 'anger',
|
165 |
+
'fear': 'fear',
|
166 |
+
'surprise': 'surprise'
|
167 |
}
|
168 |
|
169 |
+
return emotion_mapping.get(predicted_emotion, 'neutral')
|
170 |
|
171 |
def generate_text(prompt, emotion=None, max_length=100):
|
172 |
finetuned_lm_tokenizer, finetuned_lm_model = get_finetuned_lm_model()
|
|
|
231 |
emotions[emotion]['percentage'] = 10
|
232 |
emotions[emotion]['intensity'] = 0
|
233 |
emotions['ideal_state']['percentage'] = 100
|
234 |
+
return get_emotion_summary()
|
235 |
|
236 |
def respond_to_user(user_input, chat_history):
|
237 |
# Predict the emotion from the user input
|
|
|
262 |
|
263 |
return response, chat_history, get_emotion_summary()
|
264 |
|
|
|
|
|
|
|
|
|
265 |
# Gradio interface
|
266 |
+
with gr.Blocks() as demo:
|
267 |
+
gr.Markdown("# Emotion-Aware AI Chatbot")
|
268 |
+
gr.Markdown("Chat with an AI that understands and responds to emotions.")
|
269 |
+
|
270 |
+
chatbot = gr.Chatbot()
|
271 |
+
msg = gr.Textbox(label="Type your message here...")
|
272 |
+
clear = gr.Button("Clear")
|
273 |
+
|
274 |
+
emotion_state = gr.Textbox(label="Current Emotional State", lines=10)
|
275 |
+
reset_button = gr.Button("Reset Emotions")
|
276 |
+
|
277 |
+
def user(user_message, history):
|
278 |
+
response, updated_history, emotion_summary = respond_to_user(user_message, history)
|
279 |
+
return "", updated_history, emotion_summary
|
280 |
+
|
281 |
+
msg.submit(user, [msg, chatbot], [msg, chatbot, emotion_state])
|
282 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
283 |
+
reset_button.click(reset_emotions, None, emotion_state, queue=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
|
285 |
if __name__ == "__main__":
|
286 |
+
demo.launch()
|