siddhartharyaai commited on
Commit
84a3c5a
·
verified ·
1 Parent(s): 4df1c08

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +58 -16
utils.py CHANGED
@@ -17,8 +17,12 @@ import numpy as np
17
  import torch
18
  import random
19
 
 
 
 
20
  class DialogueItem(BaseModel):
21
- speaker: Literal["Jane", "John"]
 
22
  text: str
23
 
24
  class Dialogue(BaseModel):
@@ -264,20 +268,28 @@ def fetch_article_text(link: str) -> str:
264
  print(f"[ERROR] Error fetching article text: {e}")
265
  return ""
266
 
267
- def generate_script(system_prompt: str, input_text: str, tone: str, target_length: str):
 
 
 
 
 
268
  """
269
  Sends the system_prompt plus input_text to the Groq LLM to generate a
270
  multi-speaker Dialogue in JSON. We parse and return it as a Dialogue object.
271
 
272
- QUICK FIX ADDED:
273
- - If the LLM returns speakers other than "Jane" or "John,"
274
- we force them to "Jane" to satisfy the Pydantic literal constraint.
 
 
 
 
275
  """
276
  print("[LOG] Generating script with tone:", tone, "and length:", target_length)
277
  groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
278
 
279
- # Instead of a fixed mapping, parse the numeric minutes from target_length if possible
280
- # E.g. "3 Mins" -> 3 -> approximate word range
281
  words_per_minute = 150
282
  numeric_minutes = 3
283
  match = re.search(r"(\d+)", target_length)
@@ -337,14 +349,38 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
337
 
338
  json_str = raw_content[start_index:end_index+1].strip()
339
 
340
- # --- QUICK FIX: Post-process to ensure only "Jane"/"John" as speakers ---
341
  try:
342
  data = json.loads(json_str)
343
- for d in data.get("dialogue", []):
344
- if d.get("speaker") not in ["Jane", "John"]:
345
- d["speaker"] = "Jane" # Force to "Jane" or "John" (you could alternate if desired)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
 
347
- return Dialogue(**data)
348
 
349
  except json.JSONDecodeError as e:
350
  print("[ERROR] JSON decoding (format) failed:", e)
@@ -353,7 +389,6 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
353
  print("[ERROR] JSON decoding failed:", e)
354
  raise ValueError(f"Failed to parse dialogue: {str(e)}")
355
 
356
- # REPLACE the YTDLP-based approach with the RapidAPI approach
357
  def transcribe_youtube_video(video_url: str) -> str:
358
  """
359
  Transcribe the given YouTube video by calling the RapidAPI 'youtube-transcriptor' endpoint.
@@ -425,8 +460,9 @@ def generate_audio_mp3(text: str, speaker: str) -> str:
425
  # Deepgram TTS endpoint
426
  deepgram_api_url = "https://api.deepgram.com/v1/speak"
427
  params = {
428
- "model": "aura-asteria-en", # default
429
  }
 
430
  if speaker == "John":
431
  params["model"] = "aura-zeus-en"
432
 
@@ -480,6 +516,8 @@ def _preprocess_text_for_tts(text: str, speaker: str) -> str:
480
  Enhances text for natural-sounding TTS by handling abbreviations,
481
  punctuation, and intelligent filler insertion.
482
  Adjustments are made based on the speaker to optimize output quality.
 
 
483
  """
484
  # 1) Hyphens -> spaces
485
  text = re.sub(r"-", " ", text)
@@ -494,10 +532,14 @@ def _preprocess_text_for_tts(text: str, speaker: str) -> str:
494
 
495
  text = re.sub(r"\d+\.\d+", convert_decimal, text)
496
 
497
- # 3) Abbreviations (e.g., NASA -> N A S A)
 
498
  def expand_abbreviations(match):
499
  abbrev = match.group()
500
- # Check if it's plural
 
 
 
501
  if abbrev.endswith('s') and abbrev[:-1].isupper():
502
  singular = abbrev[:-1]
503
  expanded = " ".join(list(singular)) + "s"
 
17
  import torch
18
  import random
19
 
20
+ # ---------------------------------------------------------------------
21
+ # Updated: DialogueItem now has an extra field `display_speaker`
22
+ # ---------------------------------------------------------------------
23
  class DialogueItem(BaseModel):
24
+ speaker: Literal["Jane", "John"] # Used internally for TTS voice
25
+ display_speaker: str = "Jane" # The name shown in the user-facing transcript
26
  text: str
27
 
28
  class Dialogue(BaseModel):
 
268
  print(f"[ERROR] Error fetching article text: {e}")
269
  return ""
270
 
271
+ # ---------------------------------------------------------------------
272
+ # Pass host_name & guest_name so we can do "female voice" vs "male voice"
273
+ # and display_speaker vs. speaker
274
+ # ---------------------------------------------------------------------
275
+ def generate_script(system_prompt: str, input_text: str, tone: str, target_length: str,
276
+ host_name: str = "Jane", guest_name: str = "John"):
277
  """
278
  Sends the system_prompt plus input_text to the Groq LLM to generate a
279
  multi-speaker Dialogue in JSON. We parse and return it as a Dialogue object.
280
 
281
+ Logic:
282
+ - We parse the LLM's raw speaker name (e.g., "Angela", "Dimitris").
283
+ - If it matches the host_name, we set speaker="Jane" (female voice),
284
+ display_speaker = host_name.
285
+ - If it matches the guest_name, we set speaker="John" (male voice),
286
+ display_speaker = guest_name.
287
+ - If we can't match, default to "Jane" for speaker, but keep display_speaker as whatever LLM returned.
288
  """
289
  print("[LOG] Generating script with tone:", tone, "and length:", target_length)
290
  groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
291
 
292
+ # Instead of a fixed mapping, parse numeric minutes from target_length if possible
 
293
  words_per_minute = 150
294
  numeric_minutes = 3
295
  match = re.search(r"(\d+)", target_length)
 
349
 
350
  json_str = raw_content[start_index:end_index+1].strip()
351
 
 
352
  try:
353
  data = json.loads(json_str)
354
+ dialogue_list = data.get("dialogue", [])
355
+
356
+ # Post-process to ensure correct TTS speaker + custom display name
357
+ for d in dialogue_list:
358
+ raw_speaker = d.get("speaker", "Jane")
359
+ text_line = d.get("text", "")
360
+
361
+ # If raw_speaker matches host_name (case-insensitive), speaker = "Jane"
362
+ if raw_speaker.lower() == host_name.lower():
363
+ d["speaker"] = "Jane"
364
+ d["display_speaker"] = host_name
365
+ # If raw_speaker matches guest_name, speaker = "John"
366
+ elif raw_speaker.lower() == guest_name.lower():
367
+ d["speaker"] = "John"
368
+ d["display_speaker"] = guest_name
369
+ else:
370
+ # Otherwise default: we assume it's host
371
+ d["speaker"] = "Jane"
372
+ d["display_speaker"] = raw_speaker # keep the original name for display
373
+
374
+ # Now build the Dialogue object
375
+ # For any item that doesn't have display_speaker, fallback to "Jane"
376
+ new_dialogue_items = []
377
+ for d in dialogue_list:
378
+ if "display_speaker" not in d:
379
+ d["display_speaker"] = d["speaker"] # fallback
380
+ # Convert dict -> DialogueItem
381
+ new_dialogue_items.append(DialogueItem(**d))
382
 
383
+ return Dialogue(dialogue=new_dialogue_items)
384
 
385
  except json.JSONDecodeError as e:
386
  print("[ERROR] JSON decoding (format) failed:", e)
 
389
  print("[ERROR] JSON decoding failed:", e)
390
  raise ValueError(f"Failed to parse dialogue: {str(e)}")
391
 
 
392
  def transcribe_youtube_video(video_url: str) -> str:
393
  """
394
  Transcribe the given YouTube video by calling the RapidAPI 'youtube-transcriptor' endpoint.
 
460
  # Deepgram TTS endpoint
461
  deepgram_api_url = "https://api.deepgram.com/v1/speak"
462
  params = {
463
+ "model": "aura-asteria-en", # default female
464
  }
465
+ # If speaker == "John", use male voice
466
  if speaker == "John":
467
  params["model"] = "aura-zeus-en"
468
 
 
516
  Enhances text for natural-sounding TTS by handling abbreviations,
517
  punctuation, and intelligent filler insertion.
518
  Adjustments are made based on the speaker to optimize output quality.
519
+
520
+ New: We'll handle "SaaS" so that it is read as "S A A S".
521
  """
522
  # 1) Hyphens -> spaces
523
  text = re.sub(r"-", " ", text)
 
532
 
533
  text = re.sub(r"\d+\.\d+", convert_decimal, text)
534
 
535
+ # 3) Abbreviations (e.g., NASA -> N A S A).
536
+ # We'll also handle "SaaS" -> "S A A S" specifically.
537
  def expand_abbreviations(match):
538
  abbrev = match.group()
539
+ # Special handling for "SaaS" -> "S A A S"
540
+ if abbrev.lower() == "saas":
541
+ return "S A A S"
542
+ # Check if it's plural with capital letters
543
  if abbrev.endswith('s') and abbrev[:-1].isupper():
544
  singular = abbrev[:-1]
545
  expanded = " ".join(list(singular)) + "s"