Spaces:
Running
Running
Update utils.py
Browse files
utils.py
CHANGED
@@ -17,8 +17,12 @@ import numpy as np
|
|
17 |
import torch
|
18 |
import random
|
19 |
|
|
|
|
|
|
|
20 |
class DialogueItem(BaseModel):
|
21 |
-
speaker: Literal["Jane", "John"]
|
|
|
22 |
text: str
|
23 |
|
24 |
class Dialogue(BaseModel):
|
@@ -264,20 +268,28 @@ def fetch_article_text(link: str) -> str:
|
|
264 |
print(f"[ERROR] Error fetching article text: {e}")
|
265 |
return ""
|
266 |
|
267 |
-
|
|
|
|
|
|
|
|
|
|
|
268 |
"""
|
269 |
Sends the system_prompt plus input_text to the Groq LLM to generate a
|
270 |
multi-speaker Dialogue in JSON. We parse and return it as a Dialogue object.
|
271 |
|
272 |
-
|
273 |
-
|
274 |
-
|
|
|
|
|
|
|
|
|
275 |
"""
|
276 |
print("[LOG] Generating script with tone:", tone, "and length:", target_length)
|
277 |
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
278 |
|
279 |
-
# Instead of a fixed mapping, parse the numeric minutes from target_length
|
280 |
-
# E.g. "3 Mins" -> 3 -> approximate word range
|
281 |
words_per_minute = 150
|
282 |
numeric_minutes = 3
|
283 |
match = re.search(r"(\d+)", target_length)
|
@@ -337,14 +349,38 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
|
|
337 |
|
338 |
json_str = raw_content[start_index:end_index+1].strip()
|
339 |
|
340 |
-
# --- QUICK FIX: Post-process to ensure only "Jane"/"John" as speakers ---
|
341 |
try:
|
342 |
data = json.loads(json_str)
|
343 |
-
|
344 |
-
|
345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
|
347 |
-
return Dialogue(
|
348 |
|
349 |
except json.JSONDecodeError as e:
|
350 |
print("[ERROR] JSON decoding (format) failed:", e)
|
@@ -353,7 +389,6 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
|
|
353 |
print("[ERROR] JSON decoding failed:", e)
|
354 |
raise ValueError(f"Failed to parse dialogue: {str(e)}")
|
355 |
|
356 |
-
# REPLACE the YTDLP-based approach with the RapidAPI approach
|
357 |
def transcribe_youtube_video(video_url: str) -> str:
|
358 |
"""
|
359 |
Transcribe the given YouTube video by calling the RapidAPI 'youtube-transcriptor' endpoint.
|
@@ -425,8 +460,9 @@ def generate_audio_mp3(text: str, speaker: str) -> str:
|
|
425 |
# Deepgram TTS endpoint
|
426 |
deepgram_api_url = "https://api.deepgram.com/v1/speak"
|
427 |
params = {
|
428 |
-
"model": "aura-asteria-en", # default
|
429 |
}
|
|
|
430 |
if speaker == "John":
|
431 |
params["model"] = "aura-zeus-en"
|
432 |
|
@@ -480,6 +516,8 @@ def _preprocess_text_for_tts(text: str, speaker: str) -> str:
|
|
480 |
Enhances text for natural-sounding TTS by handling abbreviations,
|
481 |
punctuation, and intelligent filler insertion.
|
482 |
Adjustments are made based on the speaker to optimize output quality.
|
|
|
|
|
483 |
"""
|
484 |
# 1) Hyphens -> spaces
|
485 |
text = re.sub(r"-", " ", text)
|
@@ -494,10 +532,14 @@ def _preprocess_text_for_tts(text: str, speaker: str) -> str:
|
|
494 |
|
495 |
text = re.sub(r"\d+\.\d+", convert_decimal, text)
|
496 |
|
497 |
-
# 3) Abbreviations (e.g., NASA -> N A S A)
|
|
|
498 |
def expand_abbreviations(match):
|
499 |
abbrev = match.group()
|
500 |
-
#
|
|
|
|
|
|
|
501 |
if abbrev.endswith('s') and abbrev[:-1].isupper():
|
502 |
singular = abbrev[:-1]
|
503 |
expanded = " ".join(list(singular)) + "s"
|
|
|
17 |
import torch
|
18 |
import random
|
19 |
|
20 |
+
# ---------------------------------------------------------------------
|
21 |
+
# Updated: DialogueItem now has an extra field `display_speaker`
|
22 |
+
# ---------------------------------------------------------------------
|
23 |
class DialogueItem(BaseModel):
|
24 |
+
speaker: Literal["Jane", "John"] # Used internally for TTS voice
|
25 |
+
display_speaker: str = "Jane" # The name shown in the user-facing transcript
|
26 |
text: str
|
27 |
|
28 |
class Dialogue(BaseModel):
|
|
|
268 |
print(f"[ERROR] Error fetching article text: {e}")
|
269 |
return ""
|
270 |
|
271 |
+
# ---------------------------------------------------------------------
|
272 |
+
# Pass host_name & guest_name so we can do "female voice" vs "male voice"
|
273 |
+
# and display_speaker vs. speaker
|
274 |
+
# ---------------------------------------------------------------------
|
275 |
+
def generate_script(system_prompt: str, input_text: str, tone: str, target_length: str,
|
276 |
+
host_name: str = "Jane", guest_name: str = "John"):
|
277 |
"""
|
278 |
Sends the system_prompt plus input_text to the Groq LLM to generate a
|
279 |
multi-speaker Dialogue in JSON. We parse and return it as a Dialogue object.
|
280 |
|
281 |
+
Logic:
|
282 |
+
- We parse the LLM's raw speaker name (e.g., "Angela", "Dimitris").
|
283 |
+
- If it matches the host_name, we set speaker="Jane" (female voice),
|
284 |
+
display_speaker = host_name.
|
285 |
+
- If it matches the guest_name, we set speaker="John" (male voice),
|
286 |
+
display_speaker = guest_name.
|
287 |
+
- If we can't match, default to "Jane" for speaker, but keep display_speaker as whatever the LLM returned.
|
288 |
"""
|
289 |
print("[LOG] Generating script with tone:", tone, "and length:", target_length)
|
290 |
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
291 |
|
292 |
+
# Instead of a fixed mapping, parse numeric minutes from target_length if possible
|
|
|
293 |
words_per_minute = 150
|
294 |
numeric_minutes = 3
|
295 |
match = re.search(r"(\d+)", target_length)
|
|
|
349 |
|
350 |
json_str = raw_content[start_index:end_index+1].strip()
|
351 |
|
|
|
352 |
try:
|
353 |
data = json.loads(json_str)
|
354 |
+
dialogue_list = data.get("dialogue", [])
|
355 |
+
|
356 |
+
# Post-process to ensure correct TTS speaker + custom display name
|
357 |
+
for d in dialogue_list:
|
358 |
+
raw_speaker = d.get("speaker", "Jane")
|
359 |
+
text_line = d.get("text", "")
|
360 |
+
|
361 |
+
# If raw_speaker matches host_name (case-insensitive), speaker = "Jane"
|
362 |
+
if raw_speaker.lower() == host_name.lower():
|
363 |
+
d["speaker"] = "Jane"
|
364 |
+
d["display_speaker"] = host_name
|
365 |
+
# If raw_speaker matches guest_name, speaker = "John"
|
366 |
+
elif raw_speaker.lower() == guest_name.lower():
|
367 |
+
d["speaker"] = "John"
|
368 |
+
d["display_speaker"] = guest_name
|
369 |
+
else:
|
370 |
+
# Otherwise default: assume the line belongs to the host
|
371 |
+
d["speaker"] = "Jane"
|
372 |
+
d["display_speaker"] = raw_speaker # keep the original name for display
|
373 |
+
|
374 |
+
# Now build the Dialogue object
|
375 |
+
# For any item that doesn't have display_speaker, fall back to its "speaker" value
|
376 |
+
new_dialogue_items = []
|
377 |
+
for d in dialogue_list:
|
378 |
+
if "display_speaker" not in d:
|
379 |
+
d["display_speaker"] = d["speaker"] # fallback
|
380 |
+
# Convert dict -> DialogueItem
|
381 |
+
new_dialogue_items.append(DialogueItem(**d))
|
382 |
|
383 |
+
return Dialogue(dialogue=new_dialogue_items)
|
384 |
|
385 |
except json.JSONDecodeError as e:
|
386 |
print("[ERROR] JSON decoding (format) failed:", e)
|
|
|
389 |
print("[ERROR] JSON decoding failed:", e)
|
390 |
raise ValueError(f"Failed to parse dialogue: {str(e)}")
|
391 |
|
|
|
392 |
def transcribe_youtube_video(video_url: str) -> str:
|
393 |
"""
|
394 |
Transcribe the given YouTube video by calling the RapidAPI 'youtube-transcriptor' endpoint.
|
|
|
460 |
# Deepgram TTS endpoint
|
461 |
deepgram_api_url = "https://api.deepgram.com/v1/speak"
|
462 |
params = {
|
463 |
+
"model": "aura-asteria-en", # default female
|
464 |
}
|
465 |
+
# If speaker == "John", use male voice
|
466 |
if speaker == "John":
|
467 |
params["model"] = "aura-zeus-en"
|
468 |
|
|
|
516 |
Enhances text for natural-sounding TTS by handling abbreviations,
|
517 |
punctuation, and intelligent filler insertion.
|
518 |
Adjustments are made based on the speaker to optimize output quality.
|
519 |
+
|
520 |
+
New: We'll handle "SaaS" so that it is read as "S A A S".
|
521 |
"""
|
522 |
# 1) Hyphens -> spaces
|
523 |
text = re.sub(r"-", " ", text)
|
|
|
532 |
|
533 |
text = re.sub(r"\d+\.\d+", convert_decimal, text)
|
534 |
|
535 |
+
# 3) Abbreviations (e.g., NASA -> N A S A).
|
536 |
+
# We'll also handle "SaaS" -> "S A A S" specifically.
|
537 |
def expand_abbreviations(match):
|
538 |
abbrev = match.group()
|
539 |
+
# Special handling for "SaaS" -> "S A A S"
|
540 |
+
if abbrev.lower() == "saas":
|
541 |
+
return "S A A S"
|
542 |
+
# Check if it's plural with capital letters
|
543 |
if abbrev.endswith('s') and abbrev[:-1].isupper():
|
544 |
singular = abbrev[:-1]
|
545 |
expanded = " ".join(list(singular)) + "s"
|