mac9087 commited on
Commit
b3b50b5
·
verified ·
1 Parent(s): d1d82fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -12
app.py CHANGED
@@ -5,37 +5,101 @@ from transformers import pipeline
5
  from TTS.api import TTS
6
  import tempfile
7
  import os
 
8
 
9
  app = Flask(__name__)
10
  CORS(app)
11
 
12
  # Load models
13
  whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
14
- llm = pipeline("text-generation", model="tiiuae/falcon-rw-1b", max_new_tokens=100)
 
 
 
 
 
 
 
15
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  @app.route("/talk", methods=["POST"])
18
  def talk():
19
  if "audio" not in request.files:
20
  return jsonify({"error": "No audio file"}), 400
21
-
22
  # Save audio
23
  audio_file = request.files["audio"]
24
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
25
  audio_path = tmp.name
26
  audio_file.save(audio_path)
27
-
28
  # Transcribe
29
  segments, _ = whisper_model.transcribe(audio_path)
30
  transcription = "".join([seg.text for seg in segments])
31
-
32
  # Generate response
33
- response_text = llm(transcription)[0]["generated_text"]
34
-
 
 
 
35
  # Synthesize speech
36
  tts_audio_path = audio_path.replace(".wav", "_reply.wav")
37
- tts.tts_to_file(text=response_text, file_path=tts_audio_path)
38
-
39
  return send_file(tts_audio_path, mimetype="audio/wav")
40
 
41
  @app.route("/chat", methods=["POST"])
@@ -43,11 +107,14 @@ def chat():
43
  data = request.get_json()
44
  if not data or "text" not in data:
45
  return jsonify({"error": "Missing 'text' in request body"}), 400
46
-
47
  user_input = data["text"]
48
- response = llm(user_input)[0]["generated_text"]
49
-
50
- return jsonify({"response": response})
 
 
 
51
 
52
  @app.route("/")
53
  def index():
 
5
  from TTS.api import TTS
6
  import tempfile
7
  import os
8
+ import re
9
 
10
  app = Flask(__name__)
11
  CORS(app)
12
 
13
  # Load models
14
  whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
15
+
16
+ # Configure the LLM for short, conversational responses
17
+ llm = pipeline(
18
+ "text-generation",
19
+ model="tiiuae/falcon-rw-1b",
20
+ max_new_tokens=50, # Reduced token count for shorter responses
21
+ )
22
+
23
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
24
 
25
+ def process_response(input_text, generated_text):
26
+ # Remove the input text from the beginning of the response
27
+ if generated_text.startswith(input_text):
28
+ clean_response = generated_text[len(input_text):].strip()
29
+ else:
30
+ clean_response = generated_text.strip()
31
+
32
+ # Split into sentences and take only the first 1-2 meaningful sentences
33
+ sentences = re.split(r'(?<=[.!?])\s+', clean_response)
34
+
35
+ # Filter out empty or very short sentences
36
+ meaningful_sentences = [s for s in sentences if len(s) > 5]
37
+
38
+ # Take just 1-2 sentences for a casual, human-like response
39
+ if meaningful_sentences:
40
+ if len(meaningful_sentences) > 2:
41
+ result = " ".join(meaningful_sentences[:2])
42
+ else:
43
+ result = " ".join(meaningful_sentences)
44
+ else:
45
+ # Fallback if no good sentences were found
46
+ result = "I'm not sure what to say about that."
47
+
48
+ # Remove any repetitive phrases
49
+ result = remove_repetitions(result)
50
+
51
+ return result
52
+
53
+ def remove_repetitions(text):
54
+ # Simple repetition removal
55
+ words = text.split()
56
+ if len(words) <= 5: # Don't process very short responses
57
+ return text
58
+
59
+ result = []
60
+ for i in range(len(words)):
61
+ # Check if this word starts a repeated phrase
62
+ if i < len(words) - 3: # Need at least 3 words to check for repetition
63
+ # Check if next 3+ words appear earlier in the text
64
+ is_repetition = False
65
+ for j in range(3, min(10, len(words) - i)): # Check phrases of length 3 to 10
66
+ phrase = " ".join(words[i:i+j])
67
+ if phrase in " ".join(result):
68
+ is_repetition = True
69
+ break
70
+
71
+ if not is_repetition:
72
+ result.append(words[i])
73
+ else:
74
+ result.append(words[i])
75
+
76
+ return " ".join(result)
77
+
78
  @app.route("/talk", methods=["POST"])
79
  def talk():
80
  if "audio" not in request.files:
81
  return jsonify({"error": "No audio file"}), 400
82
+
83
  # Save audio
84
  audio_file = request.files["audio"]
85
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
86
  audio_path = tmp.name
87
  audio_file.save(audio_path)
88
+
89
  # Transcribe
90
  segments, _ = whisper_model.transcribe(audio_path)
91
  transcription = "".join([seg.text for seg in segments])
92
+
93
  # Generate response
94
+ raw_response = llm(transcription)[0]["generated_text"]
95
+
96
+ # Process to get clean, short response
97
+ final_response = process_response(transcription, raw_response)
98
+
99
  # Synthesize speech
100
  tts_audio_path = audio_path.replace(".wav", "_reply.wav")
101
+ tts.tts_to_file(text=final_response, file_path=tts_audio_path)
102
+
103
  return send_file(tts_audio_path, mimetype="audio/wav")
104
 
105
  @app.route("/chat", methods=["POST"])
 
107
  data = request.get_json()
108
  if not data or "text" not in data:
109
  return jsonify({"error": "Missing 'text' in request body"}), 400
110
+
111
  user_input = data["text"]
112
+ raw_response = llm(user_input)[0]["generated_text"]
113
+
114
+ # Process to get clean, short response
115
+ final_response = process_response(user_input, raw_response)
116
+
117
+ return jsonify({"response": final_response})
118
 
119
  @app.route("/")
120
  def index():