mac9087 committed
Commit af08117 · verified · 1 Parent(s): 7b16fc7

Update app.py

Files changed (1):
  app.py +126 -30
app.py CHANGED
@@ -5,53 +5,149 @@ from transformers import pipeline
 from TTS.api import TTS
 import tempfile
 import os
+import re
 
 app = Flask(__name__)
 CORS(app)
 
-# Load models
+# Load models once at startup for better performance
+print("Loading AI models...")
 whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
 llm = pipeline("text-generation", model="tiiuae/falcon-rw-1b", max_new_tokens=100)
 tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
+print("All models loaded successfully!")
+
+def extract_ai_response(full_text, user_input):
+    """
+    Extract only the AI's response from the generated text using multiple strategies.
+    This helps prevent the TTS engine from repeating the user's input.
+    """
+    # Strategy 1: Try to find text after "AI:" marker
+    if "AI:" in full_text:
+        try:
+            return full_text.split("AI:")[1].strip()
+        except IndexError:
+            pass  # Fall through to next strategy
+
+    # Strategy 2: Try to find text after the user input
+    if user_input in full_text:
+        try:
+            return full_text[full_text.find(user_input) + len(user_input):].strip()
+        except:
+            pass  # Fall through to next strategy
+
+    # Strategy 3: Try to split by sentences and remove the first one (likely the input)
+    try:
+        sentences = re.split(r'[.!?]\s+', full_text)
+        if len(sentences) > 1:
+            return ' '.join(sentences[1:]).strip()
+    except:
+        pass  # Fall through to fallback
+
+    # Fallback: Return the original text if all else fails
+    return full_text.strip()
 
 @app.route("/talk", methods=["POST"])
 def talk():
+    """
+    Process audio from the user:
+    1. Transcribe the audio to text
+    2. Generate an AI response to the transcription
+    3. Convert the AI response to speech
+    4. Return the speech audio file
+    """
     if "audio" not in request.files:
         return jsonify({"error": "No audio file"}), 400
-
-    # Save audio
-    audio_file = request.files["audio"]
-    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
-        audio_path = tmp.name
-        audio_file.save(audio_path)
-
-    # Transcribe
-    segments, _ = whisper_model.transcribe(audio_path)
-    transcription = "".join([seg.text for seg in segments])
-
-    # Generate response
-    response_text = llm(transcription)[0]["generated_text"]
-
-    # Synthesize speech
-    tts_audio_path = audio_path.replace(".wav", "_reply.wav")
-    tts.tts_to_file(text=response_text, file_path=tts_audio_path)
-
-    return send_file(tts_audio_path, mimetype="audio/wav")
+
+    # Create a temporary file for the input audio
+    input_audio_path = None
+    output_audio_path = None
+
+    try:
+        # Save the uploaded audio to a temporary file
+        audio_file = request.files["audio"]
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+            input_audio_path = tmp.name
+            audio_file.save(input_audio_path)
+
+        # Transcribe the audio to text
+        segments, _ = whisper_model.transcribe(input_audio_path)
+        transcription = "".join([seg.text for seg in segments]).strip()
+
+        # Check if transcription was successful
+        if not transcription:
+            return jsonify({"error": "Could not transcribe audio"}), 400
+
+        print(f"Transcribed: '{transcription}'")
+
+        # Generate AI response
+        prompt = f"User: {transcription}\nAI:"
+        response_raw = llm(prompt)[0]["generated_text"]
+
+        # Extract only the AI's response
+        ai_response = extract_ai_response(response_raw, transcription)
+        print(f"AI Response: '{ai_response}'")
+
+        # Generate speech from the AI response
+        output_audio_path = tempfile.mktemp(suffix=".wav")
+        tts.tts_to_file(text=ai_response, file_path=output_audio_path)
+
+        # Return the audio file
+        return send_file(
+            output_audio_path,
+            mimetype="audio/wav",
+            as_attachment=True,
+            download_name="ai_response.wav"
+        )
+
+    except Exception as e:
+        print(f"Error in /talk: {str(e)}")
+        return jsonify({"error": str(e)}), 500
+
+    finally:
+        # Clean up the input audio file
+        if input_audio_path and os.path.exists(input_audio_path):
+            try:
+                os.unlink(input_audio_path)
+            except Exception as e:
+                print(f"Error deleting input file: {e}")
+
+        # Note: We don't delete the output file here as Flask will handle that
+        # after the client has downloaded it
 
 @app.route("/chat", methods=["POST"])
 def chat():
-    data = request.get_json()
-    if not data or "text" not in data:
-        return jsonify({"error": "Missing 'text' in request body"}), 400
-
-    user_input = data["text"]
-    response = llm(user_input)[0]["generated_text"]
-
-    return jsonify({"response": response})
+    """
+    Process text input from the user:
+    1. Generate an AI response to the input
+    2. Return the response as JSON
+    """
+    try:
+        data = request.get_json()
+        if not data or "text" not in data:
+            return jsonify({"error": "Missing 'text' in request body"}), 400
+
+        user_input = data["text"].strip()
+        if not user_input:
+            return jsonify({"error": "Empty input"}), 400
+
+        # Generate AI response
+        prompt = f"User: {user_input}\nAI:"
+        response_raw = llm(prompt)[0]["generated_text"]
+
+        # Extract only the AI's response
+        ai_response = extract_ai_response(response_raw, user_input)
+
+        return jsonify({"response": ai_response})
+
+    except Exception as e:
+        print(f"Error in /chat: {str(e)}")
+        return jsonify({"error": str(e)}), 500
 
 @app.route("/")
 def index():
-    return "Metaverse AI Character API running."
+    """Simple route to check if the API is running"""
+    return "Metaverse AI Character API running. Models loaded and ready."
 
 if __name__ == "__main__":
-    app.run(host="0.0.0.0", port=7860)
+    app.run(host="0.0.0.0", port=7860, debug=True)
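
For reference, a minimal client sketch for the two endpoints added above (not part of the commit). It assumes the server from app.py is running locally on port 7860, that the requests package is installed, and that a recording named question.wav exists; the base URL and file names are illustrative assumptions.

import requests

BASE_URL = "http://localhost:7860"  # assumed host/port, matching app.run() above

# Text chat: POST JSON with a "text" field to /chat and read the JSON reply
chat_reply = requests.post(f"{BASE_URL}/chat", json={"text": "Hello, who are you?"})
print(chat_reply.json())  # e.g. {"response": "..."}

# Voice chat: POST a WAV file under the "audio" form field to /talk and
# save the synthesized reply audio that comes back
with open("question.wav", "rb") as f:  # hypothetical input recording
    talk_reply = requests.post(f"{BASE_URL}/talk", files={"audio": f})
with open("ai_response.wav", "wb") as out:
    out.write(talk_reply.content)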