Update app.py
app.py CHANGED
@@ -9,7 +9,7 @@ import torch
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import OneHotEncoder
-from transformers import …
+from transformers import AutoModelForSequenceClassification, T5ForConditionalGeneration, T5Tokenizer, pipeline
 from deap import base, creator, tools, algorithms
 import gc
 
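Reviewer note: the new import list drops AutoTokenizer, but the unchanged call emotion_prediction_tokenizer = AutoTokenizer.from_pretrained(...) in the next hunk still needs it. A minimal sketch of the import this hunk probably intends, assuming no other transformers symbols are used elsewhere in the file:

    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,  # still required by the unchanged emotion-prediction tokenizer
        T5ForConditionalGeneration,
        T5Tokenizer,
        pipeline,
    )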
@@ -44,18 +44,15 @@ emotion_classes = pd.Categorical(df['emotion']).categories
 emotion_prediction_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
 emotion_prediction_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
 
-# Lazy loading for the fine-tuned language model (…
+# Lazy loading for the fine-tuned language model (FLAN-T5)
 _finetuned_lm_tokenizer = None
 _finetuned_lm_model = None
 
 def get_finetuned_lm_model():
     global _finetuned_lm_tokenizer, _finetuned_lm_model
     if _finetuned_lm_tokenizer is None or _finetuned_lm_model is None:
-        _finetuned_lm_tokenizer = …
-        …
-        _finetuned_lm_tokenizer.pad_token = _finetuned_lm_tokenizer.eos_token  # Set pad token to eos token
-        _finetuned_lm_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium", device_map="auto", low_cpu_mem_usage=True)
-        _finetuned_lm_model.config.pad_token_id = _finetuned_lm_model.config.eos_token_id  # Set pad token id in the model config
+        _finetuned_lm_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
+        _finetuned_lm_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", device_map="auto", low_cpu_mem_usage=True)
     return _finetuned_lm_tokenizer, _finetuned_lm_model
 
 # Enhanced Emotional States
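This hunk swaps the lazily loaded DialoGPT-medium causal LM, along with its eos/pad-token workarounds, for FLAN-T5, which ships with a proper pad token; note that device_map="auto" requires the accelerate package at runtime. A usage sketch of the lazy loader (the prompt string is illustrative only):

    # First call loads FLAN-T5 once; subsequent calls return the cached objects.
    tokenizer, model = get_finetuned_lm_model()
    input_ids = tokenizer("Say hello.", return_tensors="pt").input_ids
    print(tokenizer.decode(model.generate(input_ids, max_length=32)[0], skip_special_tokens=True))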
@@ -131,6 +128,7 @@ def evolve_emotions():
     toolbox.register("individual", tools.initCycle, creator.Individual,
                      (toolbox.attr_float,) * (len(emotions) - 1) +
                      (toolbox.attr_intensity,) * len(emotions) +
+                     (lambda: 100,), n(toolbox.attr_intensity,) * len(emotions) +
                      (lambda: 100,), n=1)
     toolbox.register("population", tools.initRepeat, list, toolbox.individual)
     toolbox.register("mate", tools.cxTwoPoint)
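Reviewer note: the one line this hunk adds is not valid Python; n(...) calls an undefined name, and the fragment duplicates material from the neighbouring lines, so it looks like a stray paste. The registration the surrounding context already expresses, sketched here under that assumption:

    # One individual = (len(emotions) - 1) float genes, len(emotions)
    # intensity genes, and a single constant gene fixed at 100.
    toolbox.register("individual", tools.initCycle, creator.Individual,
                     (toolbox.attr_float,) * (len(emotions) - 1) +
                     (toolbox.attr_intensity,) * len(emotions) +
                     (lambda: 100,), n=1)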
@@ -174,42 +172,32 @@ def predict_emotion(context):
 def generate_text(prompt, chat_history, emotion=None, max_length=100):
     finetuned_lm_tokenizer, finetuned_lm_model = get_finetuned_lm_model()
 
-    …
+    # Prepare the input by concatenating the chat history and the new prompt
+    full_prompt = "You are Adam, an AI assistant. Respond to the following conversation in a natural and engaging way, considering the emotion: " + emotion + "\n\n"
     for turn in chat_history[-5:]:  # Consider last 5 turns for context
-        full_prompt += f"{…
-        full_prompt += f"{…
+        full_prompt += f"Human: {turn[0]}\nAdam: {turn[1]}\n"
+    full_prompt += f"Human: {prompt}\nAdam:"
 
-    input_ids = finetuned_lm_tokenizer…
+    input_ids = finetuned_lm_tokenizer(full_prompt, return_tensors="pt", max_length=512, truncation=True).input_ids
 
     if torch.cuda.is_available():
         input_ids = input_ids.cuda()
         finetuned_lm_model = finetuned_lm_model.cuda()
 
-    # Set up the emotion-specific generation parameters
-    if emotion:
-        # You can adjust these parameters based on the emotion
-        temperature = 0.7
-        top_k = 50
-        top_p = 0.9
-    else:
-        temperature = 1.0
-        top_k = 0
-        top_p = 1.0
-
     # Generate the response
-    …
+    outputs = finetuned_lm_model.generate(
         input_ids,
-        max_length=max_length,
+        max_length=max_length + input_ids.shape[1],
         num_return_sequences=1,
         no_repeat_ngram_size=2,
         do_sample=True,
-        temperature=…
-        top_k=…
-        top_p=…
+        temperature=0.7,
+        top_k=50,
+        top_p=0.95
     )
 
-    generated_text = finetuned_lm_tokenizer.decode(…
-    return generated_text
+    generated_text = finetuned_lm_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return generated_text.strip()
 
 def update_emotion_history(emotion, intensity):
     global emotion_history
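Two caveats on the rewritten generate_text, flagged as reviewer observations rather than part of the commit: emotion defaults to None, so the plain concatenation in the first added line raises TypeError whenever no emotion is supplied, and because FLAN-T5 is an encoder-decoder, its output does not include the prompt, so extending max_length by input_ids.shape[1] is a leftover causal-LM idiom. A defensive sketch of both spots, assuming "neutral" as a fallback label:

    # Hypothetical fix: never concatenate None into the prompt.
    full_prompt = ("You are Adam, an AI assistant. Respond to the following conversation "
                   "in a natural and engaging way, considering the emotion: "
                   f"{emotion or 'neutral'}\n\n")

    # Hypothetical fix: for seq2seq models, cap only the newly generated tokens.
    outputs = finetuned_lm_model.generate(
        input_ids,
        max_new_tokens=max_length,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
    )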
@@ -270,8 +258,8 @@ def respond_to_user(user_input, chat_history):
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Emotion-Aware AI Chatbot")
-    gr.Markdown("Chat with an AI that understands and responds to emotions.")
+    gr.Markdown("# Adam: Emotion-Aware AI Chatbot")
+    gr.Markdown("Chat with Adam, an AI that understands and responds to emotions.")
 
     chatbot = gr.Chatbot()
     msg = gr.Textbox(label="Type your message here...")
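The diff ends before any event wiring, so how msg reaches respond_to_user is not shown here. A minimal sketch of the usual Blocks wiring, assuming respond_to_user(user_input, chat_history) returns the updated history list (the helper name _on_submit is hypothetical):

    # Belongs inside the same `with gr.Blocks() as demo:` block.
    def _on_submit(user_input, history):
        history = respond_to_user(user_input, history or [])
        return "", history  # clear the textbox and refresh the chat display

    msg.submit(_on_submit, inputs=[msg, chatbot], outputs=[msg, chatbot])

    demo.launch()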