mac9087 committed on
Commit
c310c35
·
verified ·
1 Parent(s): b3b50b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -25
app.py CHANGED
@@ -23,12 +23,27 @@ llm = pipeline(
23
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
24
 
25
  def process_response(input_text, generated_text):
 
 
 
 
 
 
 
 
 
 
 
26
  # Remove the input text from the beginning of the response
27
- if generated_text.startswith(input_text):
28
  clean_response = generated_text[len(input_text):].strip()
29
  else:
30
  clean_response = generated_text.strip()
31
 
 
 
 
 
32
  # Split into sentences and take only the first 1-2 meaningful sentences
33
  sentences = re.split(r'(?<=[.!?])\s+', clean_response)
34
 
@@ -42,8 +57,13 @@ def process_response(input_text, generated_text):
42
  else:
43
  result = " ".join(meaningful_sentences)
44
  else:
45
- # Fallback if no good sentences were found
46
- result = "I'm not sure what to say about that."
 
 
 
 
 
47
 
48
  # Remove any repetitive phrases
49
  result = remove_repetitions(result)
@@ -86,35 +106,63 @@ def talk():
86
  audio_path = tmp.name
87
  audio_file.save(audio_path)
88
 
89
- # Transcribe
90
- segments, _ = whisper_model.transcribe(audio_path)
91
- transcription = "".join([seg.text for seg in segments])
92
-
93
- # Generate response
94
- raw_response = llm(transcription)[0]["generated_text"]
95
-
96
- # Process to get clean, short response
97
- final_response = process_response(transcription, raw_response)
98
-
99
- # Synthesize speech
100
- tts_audio_path = audio_path.replace(".wav", "_reply.wav")
101
- tts.tts_to_file(text=final_response, file_path=tts_audio_path)
102
-
103
- return send_file(tts_audio_path, mimetype="audio/wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  @app.route("/chat", methods=["POST"])
106
  def chat():
107
  data = request.get_json()
108
  if not data or "text" not in data:
109
  return jsonify({"error": "Missing 'text' in request body"}), 400
110
-
111
- user_input = data["text"]
112
- raw_response = llm(user_input)[0]["generated_text"]
113
 
114
- # Process to get clean, short response
115
- final_response = process_response(user_input, raw_response)
116
-
117
- return jsonify({"response": final_response})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  @app.route("/")
120
  def index():
 
23
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
24
 
25
  def process_response(input_text, generated_text):
26
+ # Handle the case where generated_text might be None
27
+ if not generated_text:
28
+ return "I'm not sure what to say about that."
29
+
30
+ # Make sure both are strings
31
+ input_text = str(input_text).strip()
32
+ generated_text = str(generated_text).strip()
33
+
34
+ # Skip empty input
35
+ if not input_text:
36
+ clean_response = generated_text
37
  # Remove the input text from the beginning of the response
38
+ elif generated_text.startswith(input_text):
39
  clean_response = generated_text[len(input_text):].strip()
40
  else:
41
  clean_response = generated_text.strip()
42
 
43
+ # If we ended up with nothing, provide a default response
44
+ if not clean_response:
45
+ return "I'm listening."
46
+
47
  # Split into sentences and take only the first 1-2 meaningful sentences
48
  sentences = re.split(r'(?<=[.!?])\s+', clean_response)
49
 
 
57
  else:
58
  result = " ".join(meaningful_sentences)
59
  else:
60
+ # If no meaningful sentences, but we have short sentences, use those
61
+ if sentences and any(s.strip() for s in sentences):
62
+ short_sentences = [s for s in sentences if s.strip()]
63
+ result = " ".join(short_sentences[:2])
64
+ else:
65
+ # Fallback if no good sentences were found
66
+ result = "I'm not sure what to say about that."
67
 
68
  # Remove any repetitive phrases
69
  result = remove_repetitions(result)
 
106
  audio_path = tmp.name
107
  audio_file.save(audio_path)
108
 
109
+ try:
110
+ # Transcribe
111
+ segments, _ = whisper_model.transcribe(audio_path)
112
+ transcription = "".join([seg.text for seg in segments])
113
+
114
+ print(f"Transcription: {transcription}") # Debugging
115
+
116
+ if not transcription.strip():
117
+ # Handle empty transcription
118
+ final_response = "I didn't catch that. Could you please speak again?"
119
+ else:
120
+ # Generate response
121
+ raw_response = llm(transcription)[0]["generated_text"]
122
+
123
+ # Process to get clean, short response
124
+ final_response = process_response(transcription, raw_response)
125
+
126
+ print(f"Response: {final_response}") # Debugging
127
+
128
+ # Synthesize speech
129
+ tts_audio_path = audio_path.replace(".wav", "_reply.wav")
130
+ tts.tts_to_file(text=final_response, file_path=tts_audio_path)
131
+
132
+ # Return both the audio file and the text response
133
+ response = send_file(tts_audio_path, mimetype="audio/wav")
134
+ response.headers["X-Response-Text"] = final_response
135
+
136
+ return response
137
+ except Exception as e:
138
+ print(f"Error in talk endpoint: {str(e)}")
139
+ return jsonify({"error": str(e)}), 500
140
 
141
  @app.route("/chat", methods=["POST"])
142
  def chat():
143
  data = request.get_json()
144
  if not data or "text" not in data:
145
  return jsonify({"error": "Missing 'text' in request body"}), 400
 
 
 
146
 
147
+ try:
148
+ user_input = data["text"]
149
+ print(f"Text input: {user_input}") # Debugging
150
+
151
+ # Handle empty or too short input
152
+ if not user_input or len(user_input.strip()) < 2:
153
+ return jsonify({"response": "I'm listening. Please say more."})
154
+
155
+ # Generate response
156
+ raw_response = llm(user_input)[0]["generated_text"]
157
+
158
+ # Process to get clean, short response
159
+ final_response = process_response(user_input, raw_response)
160
+ print(f"Text response: {final_response}") # Debugging
161
+
162
+ return jsonify({"response": final_response})
163
+ except Exception as e:
164
+ print(f"Error in chat endpoint: {str(e)}")
165
+ return jsonify({"response": "I'm having trouble processing that. Could you try again?", "error": str(e)})
166
 
167
  @app.route("/")
168
  def index():