mac9087 commited on
Commit
ca5149c
·
verified ·
1 Parent(s): 0017945

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -14
app.py CHANGED
@@ -9,7 +9,7 @@ import os
9
  app = Flask(__name__)
10
  CORS(app)
11
 
12
- # Load models
13
  whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
14
  llm = pipeline("text-generation", model="tiiuae/falcon-rw-1b", max_new_tokens=100)
15
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
@@ -18,24 +18,38 @@ tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False,
18
  def talk():
19
  if "audio" not in request.files:
20
  return jsonify({"error": "No audio file"}), 400
21
-
22
  # Save audio
23
  audio_file = request.files["audio"]
24
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
25
  audio_path = tmp.name
26
  audio_file.save(audio_path)
27
-
28
  # Transcribe
29
  segments, _ = whisper_model.transcribe(audio_path)
30
  transcription = "".join([seg.text for seg in segments])
31
-
32
- # Generate response
33
- response_text = llm(transcription)[0]["generated_text"]
34
-
35
- # Synthesize speech
 
 
 
 
 
 
 
 
36
  tts_audio_path = audio_path.replace(".wav", "_reply.wav")
37
- tts.tts_to_file(text=response_text, file_path=tts_audio_path)
38
-
 
 
 
 
 
 
39
  return send_file(tts_audio_path, mimetype="audio/wav")
40
 
41
  @app.route("/chat", methods=["POST"])
@@ -43,11 +57,19 @@ def chat():
43
  data = request.get_json()
44
  if not data or "text" not in data:
45
  return jsonify({"error": "Missing 'text' in request body"}), 400
46
-
47
  user_input = data["text"]
48
- response = llm(user_input)[0]["generated_text"]
49
-
50
- return jsonify({"response": response})
 
 
 
 
 
 
 
 
51
 
52
  @app.route("/")
53
  def index():
 
9
  app = Flask(__name__)
10
  CORS(app)
11
 
12
+ # Load models once at startup
13
  whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
14
  llm = pipeline("text-generation", model="tiiuae/falcon-rw-1b", max_new_tokens=100)
15
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
 
18
  def talk():
19
  if "audio" not in request.files:
20
  return jsonify({"error": "No audio file"}), 400
21
+
22
  # Save audio
23
  audio_file = request.files["audio"]
24
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
25
  audio_path = tmp.name
26
  audio_file.save(audio_path)
27
+
28
  # Transcribe
29
  segments, _ = whisper_model.transcribe(audio_path)
30
  transcription = "".join([seg.text for seg in segments])
31
+
32
+ # Generate response with a clear prompt format
33
+ prompt = f"User: {transcription}\nAI:"
34
+ response_raw = llm(prompt)[0]["generated_text"]
35
+
36
+ # Extract only the AI's response (everything after "AI:")
37
+ try:
38
+ ai_response = response_raw.split("AI:")[1].strip()
39
+ except:
40
+ # Fallback if splitting fails
41
+ ai_response = response_raw
42
+
43
+ # Synthesize speech using only the AI's response
44
  tts_audio_path = audio_path.replace(".wav", "_reply.wav")
45
+ tts.tts_to_file(text=ai_response, file_path=tts_audio_path)
46
+
47
+ # Clean up the original audio file
48
+ try:
49
+ os.unlink(audio_path)
50
+ except:
51
+ pass
52
+
53
  return send_file(tts_audio_path, mimetype="audio/wav")
54
 
55
  @app.route("/chat", methods=["POST"])
 
57
  data = request.get_json()
58
  if not data or "text" not in data:
59
  return jsonify({"error": "Missing 'text' in request body"}), 400
60
+
61
  user_input = data["text"]
62
+
63
+ # Same improvement for text chat
64
+ prompt = f"User: {user_input}\nAI:"
65
+ response_raw = llm(prompt)[0]["generated_text"]
66
+
67
+ try:
68
+ ai_response = response_raw.split("AI:")[1].strip()
69
+ except:
70
+ ai_response = response_raw
71
+
72
+ return jsonify({"response": ai_response})
73
 
74
  @app.route("/")
75
  def index():