Update utils.py
utils.py (CHANGED)
@@ -15,6 +15,7 @@ import tiktoken
 from groq import Groq  # Ensure Groq client is imported
 import numpy as np
 import torch  # Added to check CUDA availability
+import random

 class DialogueItem(BaseModel):
     speaker: Literal["Jane", "John"]
@@ -143,7 +144,6 @@ def research_topic(topic: str) -> str:
             if article_text:
                 summary_parts.append(f"From {name}: {article_text}")
             else:
-                # If no main text extracted, use title/desc
                 summary_parts.append(f"From {name}: {title} - {desc}")
         except Exception as e:
             print(f"[ERROR] Error fetching from {name} RSS feed:", e)
@@ -162,7 +162,6 @@ def research_topic(topic: str) -> str:
         print("[ERROR] Failed to retrieve additional information from LLM.")

     if not aggregated_info:
-        # No info found at all
         print("[LOG] No information found for the topic.")
         return f"Sorry, I couldn't find recent information on '{topic}'."

@@ -201,7 +200,6 @@ def fetch_rss_feed(feed_url: str) -> list:
         if resp.status_code != 200:
             print(f"[ERROR] Failed to fetch RSS feed: {feed_url} with status code {resp.status_code}")
             return []
-        # Use html.parser instead of xml to avoid needing lxml or other parsers.
         soup = BeautifulSoup(resp.content, "html.parser")
         items = soup.find_all("item")
         print(f"[LOG] Number of items fetched from {feed_url}: {len(items)}")
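As the removed comment noted, html.parser is used here so that no separate XML parser (such as lxml) is required, and it still locates <item> elements. A minimal standalone sketch of that behaviour, using a made-up feed string:

    from bs4 import BeautifulSoup

    # Invented RSS snippet, purely for illustration.
    rss = (
        "<rss><channel>"
        "<item><title>First story</title></item>"
        "<item><title>Second story</title></item>"
        "</channel></rss>"
    )
    items = BeautifulSoup(rss, "html.parser").find_all("item")
    print(len(items))            # 2
    print(items[0].get_text())   # First story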
@@ -246,10 +244,8 @@ def fetch_article_text(link: str) -> str:
             print(f"[ERROR] Failed to fetch article from link: {link} with status code {resp.status_code}")
             return ""
         soup = BeautifulSoup(resp.text, 'html.parser')
-        # This is site-specific. We'll try a generic approach:
-        # Just take all paragraphs:
         paragraphs = soup.find_all("p")
-        text = " ".join(p.get_text() for p in paragraphs[:5])
+        text = " ".join(p.get_text() for p in paragraphs[:5])
         print("[LOG] Article text fetched successfully.")
         return text.strip()
     except Exception as e:
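The paragraph-joining line flagged above is deliberately generic rather than site-specific; in isolation it behaves like this (the HTML snippet is invented for illustration):

    from bs4 import BeautifulSoup

    html = "<article><p>One.</p><p>Two.</p><p>Three.</p></article>"
    paragraphs = BeautifulSoup(html, "html.parser").find_all("p")
    text = " ".join(p.get_text() for p in paragraphs[:5])  # at most the first five paragraphs
    print(text)  # One. Two. Three.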
@@ -260,7 +256,6 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
     print("[LOG] Generating script with tone:", tone, "and length:", target_length)
     groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

-    # Map target_length to word ranges
     length_mapping = {
         "1-3 Mins": (200, 450),
         "3-5 Mins": (450, 750),
@@ -269,7 +264,6 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
     }
     min_words, max_words = length_mapping.get(target_length, (200, 450))

-    # Adjust tone description for clarity in prompt
     tone_description = {
         "Humorous": "funny and exciting, makes people chuckle",
         "Formal": "business-like, well-structured, professional",
@@ -279,7 +273,6 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt

     chosen_tone = tone_description.get(tone, "casual")

-    # Construct the prompt with clear instructions for JSON output
     prompt = (
         f"{system_prompt}\n"
         f"TONE: {chosen_tone}\n"
@@ -300,7 +293,7 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
         "}"
     )
     print("[LOG] Sending prompt to Groq:")
-    print(prompt)
+    print(prompt)

     try:
         response = groq_client.chat.completions.create(
@@ -313,14 +306,11 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
         print("[ERROR] Groq API error:", e)
         raise ValueError(f"Error communicating with Groq API: {str(e)}")

-    # Log the raw response content for debugging
     raw_content = response.choices[0].message.content.strip()
     print("[DEBUG] Raw API response content:")
     print(raw_content)

-    # Attempt to extract JSON from the response
     content = raw_content.replace('```json', '').replace('```', '').strip()
-
     start_index = content.find('{')
     end_index = content.rfind('}')

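The fence-stripping and brace-slicing above is what turns a chatty model reply into parseable JSON; a small standalone sketch of the same steps (the reply string and its schema are invented for illustration, not taken from generate_script's actual prompt):

    import json

    # Hypothetical raw LLM reply: chatter plus a fenced JSON block.
    raw_content = (
        "Sure! Here is the script:\n"
        "```json\n"
        '{"dialogue": [{"speaker": "Jane", "text": "Hi!"}]}\n'
        "```\n"
        "Hope that helps."
    )

    # Same recovery steps as above: drop the fences, then keep only the
    # span between the first '{' and the last '}'.
    content = raw_content.replace('```json', '').replace('```', '').strip()
    start_index = content.find('{')
    end_index = content.rfind('}')
    parsed = json.loads(content[start_index:end_index + 1])
    print(parsed["dialogue"][0]["speaker"])  # Jane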
@@ -345,55 +335,120 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
         print(content)
         raise ValueError(f"Failed to parse dialogue: {str(e)}")

+# ----------------------------------------------------------------------
+# We ONLY modify the generate_audio_mp3 flow below to insert random filler words
+# and modify punctuation (.,!?) for more natural TTS pauses and intonation.
+# ----------------------------------------------------------------------
+
+def _make_text_sound_more_human(text: str) -> str:
+    """
+    Inserts small filler words and adds extra punctuation to encourage
+    natural-sounding pauses at commas, periods, exclamations, and question marks.
+    """
+
+    # Filler words or short phrases
+    fillers = ["uh", "um", "ah", "hmm", "you know", "well", "I mean", "like"]
+
+    # 1) Split text by punctuation but keep the punctuation in the result.
+    #    We'll handle ".", "?", "!", and commas:
+    pattern = r'([.,?!])'
+    parts = re.split(pattern, text)
+
+    # 2) Process each chunk, occasionally inserting filler words or extra punctuation
+    processed_chunks = []
+    for i in range(len(parts)):
+        chunk = parts[i].strip()
+
+        # If the chunk is punctuation, keep it
+        if chunk in [".", ",", "?", "!"]:
+            # Possibly turn "." into "..." or add "..." after "?"
+            if chunk == "." and random.random() < 0.5:
+                chunk = "..."
+            elif chunk == "?" and random.random() < 0.3:
+                # Sometimes add "?!"
+                chunk = "?!"
+            elif chunk == "!" and random.random() < 0.3:
+                # Sometimes add "!!" for more emphasis
+                chunk = "!!"
+            processed_chunks.append(chunk)
+            continue
+
+        # Sometimes insert a filler at the start or mid-chunk
+        if chunk and random.random() < 0.3:
+            filler = random.choice(fillers)
+            # Insert at the beginning or in the middle
+            if random.random() < 0.5:
+                chunk = f"{filler}, {chunk}"
+            else:
+                # Insert near the middle
+                words = chunk.split()
+                mid = len(words) // 2
+                chunk = " ".join(words[:mid] + [f"{filler},"] + words[mid:])
+
+        processed_chunks.append(chunk)
+
+    # 3) Rejoin them carefully with a space or nothing.
+    #    We'll add a small space after punctuation, so TTS sees them as separate tokens
+    out_text = []
+    for i in range(len(processed_chunks)):
+        if i == 0:
+            out_text.append(processed_chunks[i])
+        else:
+            # If the previous chunk was punctuation or the current chunk is punctuation
+            if processed_chunks[i] in [".", "...", "?", "?!", "!", "!!", ","]:
+                out_text.append(processed_chunks[i])
+            else:
+                out_text.append(" " + processed_chunks[i])
+
+    final_text = "".join(out_text)
+    return final_text.strip()
+
 def generate_audio_mp3(text: str, speaker: str) -> str:
     try:
         print(f"[LOG] Generating audio for speaker: {speaker}")
-
+
+        # Make text more "human-like"
+        text = _make_text_sound_more_human(text)
+
         # Define Deepgram API endpoint
         deepgram_api_url = "https://api.deepgram.com/v1/speak"

         # Prepare query parameters
         params = {
             "model": "aura-asteria-en",  # Default model; adjust if needed
-            # You can add more parameters here as needed, e.g., bit_rate, sample_rate, etc.
         }

         # Override model if needed based on speaker
         if speaker == "Jane":
-            params["model"] = "aura-asteria-en"
+            params["model"] = "aura-asteria-en"
         elif speaker == "John":
-            params["model"] = "aura-perseus-en"
+            params["model"] = "aura-perseus-en"
         else:
             raise ValueError(f"Unknown speaker: {speaker}")

-        # Prepare headers
         headers = {
-            "Accept": "audio/mpeg",
+            "Accept": "audio/mpeg",
             "Content-Type": "application/json",
             "Authorization": f"Token {os.environ.get('DEEPGRAM_API_KEY')}"
         }

-        # Prepare body
         body = {
             "text": text
         }

         print("[LOG] Sending TTS request to Deepgram...")
-        # Make the POST request to Deepgram's TTS API
         response = requests.post(deepgram_api_url, params=params, headers=headers, json=body, stream=True)

         if response.status_code != 200:
             print(f"[ERROR] Deepgram TTS API returned status code {response.status_code}: {response.text}")
             raise ValueError(f"Deepgram TTS API error: {response.status_code} - {response.text}")

-        # Verify Content-Type
         content_type = response.headers.get('Content-Type', '')
         if 'audio/mpeg' not in content_type:
             print("[ERROR] Unexpected Content-Type received from Deepgram:", content_type)
             print("[ERROR] Response content:", response.text)
             raise ValueError("Unexpected Content-Type received from Deepgram.")

-        # Save the streamed audio to a temporary MP3 file
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_file:
             for chunk in response.iter_content(chunk_size=8192):
                 if chunk:
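The new _make_text_sound_more_human helper relies on re.split with a capturing group, which keeps the punctuation tokens in the output list; a quick standalone check of that behaviour (the sentence is arbitrary):

    import re

    parts = re.split(r'([.,?!])', "Welcome back. Today we talk about AI, okay?")
    print(parts)
    # ['Welcome back', '.', ' Today we talk about AI', ',', ' okay', '?', '']

    # The filler insertion itself uses random.random()/random.choice, so the helper's
    # output varies from run to run; calling random.seed(...) first would make a run repeatable.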
@@ -405,19 +460,10 @@ def generate_audio_mp3(text: str, speaker: str) -> str:
         audio_seg = AudioSegment.from_file(mp3_temp_path, format="mp3")
         audio_seg = effects.normalize(audio_seg)

-        # Removed pitch shifting for male voice
-        # Previously:
-        # if speaker == "John":
-        #     semitones = -5  # Shift down by 5 semitones for a deeper voice
-        #     audio_seg = pitch_shift(audio_seg, semitones=semitones)
-        #     print(f"[LOG] Applied pitch shift to John's voice by {semitones} semitones.")
-
-        # Export the final audio as MP3
         final_mp3_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
         audio_seg.export(final_mp3_path, format="mp3")
         print("[LOG] Audio post-processed and saved at:", final_mp3_path)

-        # Clean up the initial MP3 file
         if os.path.exists(mp3_temp_path):
             os.remove(mp3_temp_path)
             print(f"[LOG] Removed temporary MP3 file: {mp3_temp_path}")
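Taken together, the changed flow can be exercised with a call like the sketch below. It assumes DEEPGRAM_API_KEY is set to a valid key and that this file is importable as utils; both are assumptions of the sketch, not part of the diff, and the return value is expected to be the path of the post-processed MP3 (the function is annotated -> str):

    import os
    import utils  # assumed import name for this utils.py

    os.environ["DEEPGRAM_API_KEY"] = "<your-deepgram-api-key>"  # placeholder, replace with a real key

    # "Jane" maps to aura-asteria-en and "John" to aura-perseus-en, per the diff above.
    mp3_path = utils.generate_audio_mp3("Welcome back to the show. Today we talk about AI!", "John")
    print("Post-processed MP3 saved at:", mp3_path)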
@@ -462,7 +508,6 @@ def transcribe_youtube_video(video_url: str) -> str:
         print("[ERROR] ASR transcription error:", e)
         raise ValueError(f"Error transcribing YouTube video: {str(e)}")
     finally:
-        # Clean up the downloaded audio file
         if os.path.exists(audio_file):
             os.remove(audio_file)
             print(f"[LOG] Removed temporary audio file: {audio_file}")