Update app.py
Browse files
app.py
CHANGED
@@ -5,37 +5,101 @@ from transformers import pipeline
|
|
5 |
from TTS.api import TTS
|
6 |
import tempfile
|
7 |
import os
|
|
|
8 |
|
9 |
app = Flask(__name__)
|
10 |
CORS(app)
|
11 |
|
12 |
# Load models
|
13 |
whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
@app.route("/talk", methods=["POST"])
|
18 |
def talk():
|
19 |
if "audio" not in request.files:
|
20 |
return jsonify({"error": "No audio file"}), 400
|
21 |
-
|
22 |
# Save audio
|
23 |
audio_file = request.files["audio"]
|
24 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
25 |
audio_path = tmp.name
|
26 |
audio_file.save(audio_path)
|
27 |
-
|
28 |
# Transcribe
|
29 |
segments, _ = whisper_model.transcribe(audio_path)
|
30 |
transcription = "".join([seg.text for seg in segments])
|
31 |
-
|
32 |
# Generate response
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
35 |
# Synthesize speech
|
36 |
tts_audio_path = audio_path.replace(".wav", "_reply.wav")
|
37 |
-
tts.tts_to_file(text=
|
38 |
-
|
39 |
return send_file(tts_audio_path, mimetype="audio/wav")
|
40 |
|
41 |
@app.route("/chat", methods=["POST"])
|
@@ -43,11 +107,14 @@ def chat():
|
|
43 |
data = request.get_json()
|
44 |
if not data or "text" not in data:
|
45 |
return jsonify({"error": "Missing 'text' in request body"}), 400
|
46 |
-
|
47 |
user_input = data["text"]
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
51 |
|
52 |
@app.route("/")
|
53 |
def index():
|
|
|
5 |
from TTS.api import TTS
|
6 |
import tempfile
|
7 |
import os
|
8 |
+
import re
|
9 |
|
10 |
app = Flask(__name__)
|
11 |
CORS(app)
|
12 |
|
13 |
# Load models
|
14 |
whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
|
15 |
+
|
16 |
+
# Configure the LLM for short, conversational responses
|
17 |
+
llm = pipeline(
|
18 |
+
"text-generation",
|
19 |
+
model="tiiuae/falcon-rw-1b",
|
20 |
+
max_new_tokens=50, # Reduced token count for shorter responses
|
21 |
+
)
|
22 |
+
|
23 |
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
|
24 |
|
25 |
+
def process_response(input_text, generated_text):
    """Reduce raw model output to a short reply of at most two sentences.

    Args:
        input_text: The prompt that was given to the text-generation model.
        generated_text: The model's raw output, which may begin with an
            echo of ``input_text``.

    Returns:
        The first one or two sentences longer than five characters, passed
        through ``remove_repetitions``. A fixed fallback sentence is used
        when no sentence qualifies.
    """
    # Drop the echoed prompt, if the output starts with it.
    body = generated_text
    if body.startswith(input_text):
        body = body[len(input_text):]
    body = body.strip()

    # Split after sentence-ending punctuation and discard tiny fragments.
    candidates = [
        piece
        for piece in re.split(r'(?<=[.!?])\s+', body)
        if len(piece) > 5
    ]

    if not candidates:
        # Nothing usable came back from the model.
        reply = "I'm not sure what to say about that."
    else:
        # Slicing already handles the one- and two-sentence cases.
        reply = " ".join(candidates[:2])

    return remove_repetitions(reply)
|
52 |
+
|
53 |
+
def _repeated_run_length(words, start, kept, min_len=3, max_len=10):
    """Return the length of the longest already-emitted phrase at ``start``.

    Looks for the longest run ``words[start:start + n]`` with
    ``min_len <= n <= max_len`` that occurs as a contiguous, word-aligned
    sequence inside ``kept``. Returns 0 when no such run exists.
    """
    longest = 0
    limit = min(max_len, len(words) - start)
    for size in range(min_len, limit + 1):
        phrase = words[start:start + size]
        seen = any(
            kept[pos:pos + size] == phrase
            for pos in range(len(kept) - size + 1)
        )
        if not seen:
            # A longer phrase cannot match if its own prefix does not.
            break
        longest = size
    return longest


def remove_repetitions(text):
    """Remove phrases of three or more words that repeat earlier output.

    The text is scanned word by word. Whenever the words at the current
    position repeat a phrase (3 to 10 words) that has already been kept,
    the whole repeated phrase is skipped. Matching is done on whole words,
    so a phrase never matches across partial words.

    Args:
        text: The text to clean up.

    Returns:
        ``text`` unchanged when it has five words or fewer; otherwise the
        kept words joined by single spaces.
    """
    words = text.split()
    if len(words) <= 5:  # Don't process very short responses
        return text

    kept = []
    index = 0
    while index < len(words):
        run = _repeated_run_length(words, index, kept)
        if run:
            # Skip the entire repeated phrase, not just its first word,
            # so no trailing stub of the repetition is left behind.
            index += run
        else:
            kept.append(words[index])
            index += 1

    return " ".join(kept)
|
77 |
+
|
78 |
@app.route("/talk", methods=["POST"])
|
79 |
def talk():
    """Answer an uploaded voice recording with a synthesized spoken reply.

    Expects a POST whose files include the recording under the ``audio``
    field. The recording is transcribed with ``whisper_model``, the
    transcription is passed to ``llm``, the output is shortened by
    ``process_response``, and the result is synthesized with ``tts``.

    Returns:
        A WAV response containing the spoken reply, or a JSON error with
        status 400 when no ``audio`` file was uploaded.
    """
    import io  # local import: only this handler needs an in-memory buffer

    if "audio" not in request.files:
        return jsonify({"error": "No audio file"}), 400

    # Reserve a temp path for the upload; the handle is closed before the
    # path is written to so the file is not opened twice at once.
    audio_file = request.files["audio"]
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        audio_path = tmp.name
    # Swap only the trailing suffix; str.replace would rewrite every
    # ".wav" occurrence in the path.
    tts_audio_path = audio_path[:-len(".wav")] + "_reply.wav"

    try:
        audio_file.save(audio_path)

        # Transcribe
        segments, _ = whisper_model.transcribe(audio_path)
        transcription = "".join(seg.text for seg in segments)

        # Generate response
        raw_response = llm(transcription)[0]["generated_text"]

        # Process to get clean, short response
        final_response = process_response(transcription, raw_response)

        # Synthesize speech
        tts.tts_to_file(text=final_response, file_path=tts_audio_path)

        # Load the reply into memory so the file can be deleted before
        # the response is sent.
        with open(tts_audio_path, "rb") as reply_file:
            reply_audio = io.BytesIO(reply_file.read())
    finally:
        # Both temp files were created outside tempfile's auto-delete, so
        # remove them here on success and on failure alike.
        for path in (audio_path, tts_audio_path):
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup: the reply file may not exist if an
                # earlier step failed, and cleanup must not mask that error.
                pass

    return send_file(reply_audio, mimetype="audio/wav")
|
104 |
|
105 |
@app.route("/chat", methods=["POST"])
|
|
|
107 |
data = request.get_json()
|
108 |
if not data or "text" not in data:
|
109 |
return jsonify({"error": "Missing 'text' in request body"}), 400
|
110 |
+
|
111 |
user_input = data["text"]
|
112 |
+
raw_response = llm(user_input)[0]["generated_text"]
|
113 |
+
|
114 |
+
# Process to get clean, short response
|
115 |
+
final_response = process_response(user_input, raw_response)
|
116 |
+
|
117 |
+
return jsonify({"response": final_response})
|
118 |
|
119 |
@app.route("/")
|
120 |
def index():
|