Update app.py
Browse files
app.py
CHANGED
@@ -23,12 +23,27 @@ llm = pipeline(
|
|
23 |
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
|
24 |
|
25 |
def process_response(input_text, generated_text):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
# Remove the input text from the beginning of the response
|
27 |
-
|
28 |
clean_response = generated_text[len(input_text):].strip()
|
29 |
else:
|
30 |
clean_response = generated_text.strip()
|
31 |
|
|
|
|
|
|
|
|
|
32 |
# Split into sentences and take only the first 1-2 meaningful sentences
|
33 |
sentences = re.split(r'(?<=[.!?])\s+', clean_response)
|
34 |
|
@@ -42,8 +57,13 @@ def process_response(input_text, generated_text):
|
|
42 |
else:
|
43 |
result = " ".join(meaningful_sentences)
|
44 |
else:
|
45 |
-
#
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
# Remove any repetitive phrases
|
49 |
result = remove_repetitions(result)
|
@@ -86,35 +106,63 @@ def talk():
|
|
86 |
audio_path = tmp.name
|
87 |
audio_file.save(audio_path)
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
@app.route("/chat", methods=["POST"])
|
106 |
def chat():
|
107 |
data = request.get_json()
|
108 |
if not data or "text" not in data:
|
109 |
return jsonify({"error": "Missing 'text' in request body"}), 400
|
110 |
-
|
111 |
-
user_input = data["text"]
|
112 |
-
raw_response = llm(user_input)[0]["generated_text"]
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
@app.route("/")
|
120 |
def index():
|
|
|
23 |
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
|
24 |
|
25 |
def process_response(input_text, generated_text):
|
26 |
+
# Handle the case where generated_text might be None
|
27 |
+
if not generated_text:
|
28 |
+
return "I'm not sure what to say about that."
|
29 |
+
|
30 |
+
# Make sure both are strings
|
31 |
+
input_text = str(input_text).strip()
|
32 |
+
generated_text = str(generated_text).strip()
|
33 |
+
|
34 |
+
# Skip empty input
|
35 |
+
if not input_text:
|
36 |
+
clean_response = generated_text
|
37 |
# Remove the input text from the beginning of the response
|
38 |
+
elif generated_text.startswith(input_text):
|
39 |
clean_response = generated_text[len(input_text):].strip()
|
40 |
else:
|
41 |
clean_response = generated_text.strip()
|
42 |
|
43 |
+
# If we ended up with nothing, provide a default response
|
44 |
+
if not clean_response:
|
45 |
+
return "I'm listening."
|
46 |
+
|
47 |
# Split into sentences and take only the first 1-2 meaningful sentences
|
48 |
sentences = re.split(r'(?<=[.!?])\s+', clean_response)
|
49 |
|
|
|
57 |
else:
|
58 |
result = " ".join(meaningful_sentences)
|
59 |
else:
|
60 |
+
# If no meaningful sentences, but we have short sentences, use those
|
61 |
+
if sentences and any(s.strip() for s in sentences):
|
62 |
+
short_sentences = [s for s in sentences if s.strip()]
|
63 |
+
result = " ".join(short_sentences[:2])
|
64 |
+
else:
|
65 |
+
# Fallback if no good sentences were found
|
66 |
+
result = "I'm not sure what to say about that."
|
67 |
|
68 |
# Remove any repetitive phrases
|
69 |
result = remove_repetitions(result)
|
|
|
106 |
audio_path = tmp.name
|
107 |
audio_file.save(audio_path)
|
108 |
|
109 |
+
try:
|
110 |
+
# Transcribe
|
111 |
+
segments, _ = whisper_model.transcribe(audio_path)
|
112 |
+
transcription = "".join([seg.text for seg in segments])
|
113 |
+
|
114 |
+
print(f"Transcription: {transcription}") # Debugging
|
115 |
+
|
116 |
+
if not transcription.strip():
|
117 |
+
# Handle empty transcription
|
118 |
+
final_response = "I didn't catch that. Could you please speak again?"
|
119 |
+
else:
|
120 |
+
# Generate response
|
121 |
+
raw_response = llm(transcription)[0]["generated_text"]
|
122 |
+
|
123 |
+
# Process to get clean, short response
|
124 |
+
final_response = process_response(transcription, raw_response)
|
125 |
+
|
126 |
+
print(f"Response: {final_response}") # Debugging
|
127 |
+
|
128 |
+
# Synthesize speech
|
129 |
+
tts_audio_path = audio_path.replace(".wav", "_reply.wav")
|
130 |
+
tts.tts_to_file(text=final_response, file_path=tts_audio_path)
|
131 |
+
|
132 |
+
# Return both the audio file and the text response
|
133 |
+
response = send_file(tts_audio_path, mimetype="audio/wav")
|
134 |
+
response.headers["X-Response-Text"] = final_response
|
135 |
+
|
136 |
+
return response
|
137 |
+
except Exception as e:
|
138 |
+
print(f"Error in talk endpoint: {str(e)}")
|
139 |
+
return jsonify({"error": str(e)}), 500
|
140 |
|
141 |
@app.route("/chat", methods=["POST"])
|
142 |
def chat():
|
143 |
data = request.get_json()
|
144 |
if not data or "text" not in data:
|
145 |
return jsonify({"error": "Missing 'text' in request body"}), 400
|
|
|
|
|
|
|
146 |
|
147 |
+
try:
|
148 |
+
user_input = data["text"]
|
149 |
+
print(f"Text input: {user_input}") # Debugging
|
150 |
+
|
151 |
+
# Handle empty or too short input
|
152 |
+
if not user_input or len(user_input.strip()) < 2:
|
153 |
+
return jsonify({"response": "I'm listening. Please say more."})
|
154 |
+
|
155 |
+
# Generate response
|
156 |
+
raw_response = llm(user_input)[0]["generated_text"]
|
157 |
+
|
158 |
+
# Process to get clean, short response
|
159 |
+
final_response = process_response(user_input, raw_response)
|
160 |
+
print(f"Text response: {final_response}") # Debugging
|
161 |
+
|
162 |
+
return jsonify({"response": final_response})
|
163 |
+
except Exception as e:
|
164 |
+
print(f"Error in chat endpoint: {str(e)}")
|
165 |
+
return jsonify({"response": "I'm having trouble processing that. Could you try again?", "error": str(e)})
|
166 |
|
167 |
@app.route("/")
|
168 |
def index():
|