mac9087 commited on
Commit
d1d82fe
·
verified ·
1 Parent(s): af08117

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -126
app.py CHANGED
@@ -5,149 +5,53 @@ from transformers import pipeline
5
  from TTS.api import TTS
6
  import tempfile
7
  import os
8
- import re
9
 
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
- # Load models once at startup for better performance
14
- print("Loading AI models...")
15
  whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
16
  llm = pipeline("text-generation", model="tiiuae/falcon-rw-1b", max_new_tokens=100)
17
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
18
- print("All models loaded successfully!")
19
-
20
def extract_ai_response(full_text, user_input):
    """
    Extract only the AI's response from the generated text.

    Text-generation pipelines return prompt + completion, so the raw output
    contains the user's input; feeding it straight to TTS would replay the
    user's own words.  Three strategies are tried in order.

    :param full_text: raw generated text (prompt + completion).
    :param user_input: the user's original message, used to locate the echo.
    :return: the best-effort reply substring, stripped of whitespace.
    """
    # Strategy 1: text after the "AI:" marker.  The containment check
    # guarantees split() yields at least two parts, so no try/except needed
    # (the original wrapped this in a dead IndexError handler).
    if "AI:" in full_text:
        return full_text.split("AI:")[1].strip()

    # Strategy 2: text after the echoed user input.  find() cannot fail here
    # because of the containment check, so the original bare `except:` was
    # both dead and an anti-pattern.
    if user_input in full_text:
        start = full_text.find(user_input) + len(user_input)
        return full_text[start:].strip()

    # Strategy 3: drop the first sentence, which is most likely the echoed
    # input, and return the rest.
    sentences = re.split(r'[.!?]\s+', full_text)
    if len(sentences) > 1:
        return ' '.join(sentences[1:]).strip()

    # Fallback: return the original text if all else fails.
    return full_text.strip()
49
 
50
@app.route("/talk", methods=["POST"])
def talk():
    """
    Process audio from the user:
    1. Transcribe the audio to text
    2. Generate an AI response to the transcription
    3. Convert the AI response to speech
    4. Return the speech audio file

    Returns 400 for a missing file or empty transcription, 500 on any
    processing failure.
    """
    if "audio" not in request.files:
        return jsonify({"error": "No audio file"}), 400

    input_audio_path = None
    output_audio_path = None

    try:
        # Save the uploaded audio to a temporary file
        audio_file = request.files["audio"]
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            input_audio_path = tmp.name
            audio_file.save(input_audio_path)

        # Transcribe the audio to text
        segments, _ = whisper_model.transcribe(input_audio_path)
        transcription = "".join(seg.text for seg in segments).strip()

        if not transcription:
            return jsonify({"error": "Could not transcribe audio"}), 400

        print(f"Transcribed: '{transcription}'")

        # Generate AI response; the User/AI framing pairs with the "AI:"
        # marker that extract_ai_response looks for.
        prompt = f"User: {transcription}\nAI:"
        response_raw = llm(prompt)[0]["generated_text"]

        # Extract only the AI's response
        ai_response = extract_ai_response(response_raw, transcription)
        print(f"AI Response: '{ai_response}'")

        # Generate speech.  tempfile.mktemp is deprecated and race-prone
        # (the name can be claimed between creation and use), so create the
        # output file securely with NamedTemporaryFile instead.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
            output_audio_path = out.name
        tts.tts_to_file(text=ai_response, file_path=output_audio_path)

        # Return the audio file
        return send_file(
            output_audio_path,
            mimetype="audio/wav",
            as_attachment=True,
            download_name="ai_response.wav"
        )

    except Exception as e:
        print(f"Error in /talk: {str(e)}")
        return jsonify({"error": str(e)}), 500

    finally:
        # Clean up the input audio file; only filesystem errors are expected
        # here, so catch OSError rather than Exception.
        if input_audio_path and os.path.exists(input_audio_path):
            try:
                os.unlink(input_audio_path)
            except OSError as e:
                print(f"Error deleting input file: {e}")

        # Note: We don't delete the output file here as Flask will handle
        # streaming it to the client after this function returns.
117
 
118
@app.route("/chat", methods=["POST"])
def chat():
    """
    Handle one text-chat turn: validate the JSON payload, run the language
    model on the user's text, and return the extracted reply as JSON.
    """
    try:
        payload = request.get_json()

        # Guard clauses: reject missing or blank input up front.
        if not payload or "text" not in payload:
            return jsonify({"error": "Missing 'text' in request body"}), 400

        user_input = payload["text"].strip()
        if not user_input:
            return jsonify({"error": "Empty input"}), 400

        # Run the LLM on a User/AI-framed prompt, then keep only the part
        # that is the model's reply (the pipeline echoes the prompt).
        generated = llm(f"User: {user_input}\nAI:")[0]["generated_text"]
        reply = extract_ai_response(generated, user_input)

        return jsonify({"response": reply})

    except Exception as e:
        print(f"Error in /chat: {str(e)}")
        return jsonify({"error": str(e)}), 500
146
 
147
@app.route("/")
def index():
    """Health-check endpoint: confirms the API is up and models are loaded."""
    return "Metaverse AI Character API running. Models loaded and ready."
151
 
152
if __name__ == "__main__":
    # Development entry point: listen on all interfaces; 7860 is the
    # Hugging Face Spaces default port.
    app.run(host="0.0.0.0", port=7860, debug=True)
 
5
  from TTS.api import TTS
6
  import tempfile
7
  import os
 
8
 
9
  app = Flask(__name__)
10
  CORS(app)
11
 
12
# Load the three models once at import time so every request reuses them:
# Whisper (speech-to-text), Falcon-RW-1B (text generation), Tacotron2 (TTS).
whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
llm = pipeline("text-generation", model="tiiuae/falcon-rw-1b", max_new_tokens=100)
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
@app.route("/talk", methods=["POST"])
def talk():
    """
    Voice endpoint: transcribe the uploaded audio, generate an LLM reply,
    synthesize it to speech, and return the WAV file.

    Returns 400 for a missing file or empty transcription.
    """
    if "audio" not in request.files:
        return jsonify({"error": "No audio file"}), 400

    audio_path = None
    try:
        # Save audio to a temporary file
        audio_file = request.files["audio"]
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            audio_path = tmp.name
            audio_file.save(audio_path)

        # Transcribe
        segments, _ = whisper_model.transcribe(audio_path)
        transcription = "".join(seg.text for seg in segments).strip()
        if not transcription:
            return jsonify({"error": "Could not transcribe audio"}), 400

        # Generate response.  Bug fix: text-generation pipelines return
        # prompt + completion, so passing the raw transcription and then
        # TTS'ing generated_text replayed the user's own words back at them.
        # Frame the prompt and strip the echoed part before synthesis.
        prompt = f"User: {transcription}\nAI:"
        generated = llm(prompt)[0]["generated_text"]
        if generated.startswith(prompt):
            response_text = generated[len(prompt):].strip()
        else:
            response_text = generated.strip()

        # Synthesize speech
        tts_audio_path = audio_path.replace(".wav", "_reply.wav")
        tts.tts_to_file(text=response_text, file_path=tts_audio_path)

        return send_file(tts_audio_path, mimetype="audio/wav")

    finally:
        # Bug fix: the original leaked the uploaded temp WAV on every
        # request; delete it once processing is done (best-effort).
        if audio_path and os.path.exists(audio_path):
            try:
                os.unlink(audio_path)
            except OSError:
                pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
@app.route("/chat", methods=["POST"])
def chat():
    """
    Text endpoint: generate an LLM reply to a JSON body of the form
    {"text": "..."} and return it as {"response": "..."}.

    Returns 400 when the body is missing, lacks "text", or is blank.
    """
    data = request.get_json()
    if not data or "text" not in data:
        return jsonify({"error": "Missing 'text' in request body"}), 400

    # Robustness fix: reject whitespace-only input instead of running the
    # model on an empty prompt.
    user_input = data["text"].strip()
    if not user_input:
        return jsonify({"error": "Empty input"}), 400

    # Bug fix: text-generation pipelines echo the prompt at the start of
    # generated_text; strip it so the client receives only the reply.
    response = llm(user_input)[0]["generated_text"]
    if response.startswith(user_input):
        response = response[len(user_input):].strip()

    return jsonify({"response": response})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
@app.route("/")
def index():
    """Health-check route: confirms the API is reachable."""
    return "Metaverse AI Character API running."
 
55
 
56
if __name__ == "__main__":
    # Serve on all interfaces; 7860 is the Hugging Face Spaces default port.
    app.run(host="0.0.0.0", port=7860)