siddhartharyaai commited on
Commit
2bcba5d
·
verified ·
1 Parent(s): 5198e6d

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +51 -101
utils.py CHANGED
@@ -18,14 +18,14 @@ import torch
18
  import random
19
 
20
  class DialogueItem(BaseModel):
21
- speaker: Literal["Jane", "John"] # TTS voice
22
- display_speaker: str = "Jane" # For display in transcript
23
  text: str
24
 
25
  class Dialogue(BaseModel):
26
  dialogue: List[DialogueItem]
27
 
28
- # Initialize Whisper (unused for YouTube with RapidAPI)
29
  asr_pipeline = pipeline(
30
  "automatic-speech-recognition",
31
  model="openai/whisper-tiny.en",
@@ -33,10 +33,6 @@ asr_pipeline = pipeline(
33
  )
34
 
35
  def truncate_text(text, max_tokens=2048):
36
- """
37
- If the text exceeds the max token limit (approx. 2,048), truncate it
38
- to avoid exceeding the model's context window.
39
- """
40
  print("[LOG] Truncating text if needed.")
41
  tokenizer = tiktoken.get_encoding("cl100k_base")
42
  tokens = tokenizer.encode(text)
@@ -46,10 +42,6 @@ def truncate_text(text, max_tokens=2048):
46
  return text
47
 
48
  def extract_text_from_url(url):
49
- """
50
- Fetches and extracts readable text from a given URL
51
- (stripping out scripts, styles, etc.).
52
- """
53
  print("[LOG] Extracting text from URL:", url)
54
  try:
55
  headers = {
@@ -74,29 +66,17 @@ def extract_text_from_url(url):
74
  return ""
75
 
76
  def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
77
- """
78
- Shifts the pitch of an AudioSegment by a given number of semitones.
79
- Positive semitones shift the pitch up, negative shifts it down.
80
- """
81
  print(f"[LOG] Shifting pitch by {semitones} semitones.")
82
  new_sample_rate = int(audio.frame_rate * (2.0 ** (semitones / 12.0)))
83
  shifted_audio = audio._spawn(audio.raw_data, overrides={'frame_rate': new_sample_rate})
84
  return shifted_audio.set_frame_rate(audio.frame_rate)
85
 
86
  def is_sufficient(text: str, min_word_count: int = 500) -> bool:
87
- """
88
- Checks if the fetched text meets our sufficiency criteria
89
- (e.g., at least 500 words).
90
- """
91
  word_count = len(text.split())
92
  print(f"[DEBUG] Aggregated word count: {word_count}")
93
  return word_count >= min_word_count
94
 
95
  def query_llm_for_additional_info(topic: str, existing_text: str) -> str:
96
- """
97
- Queries the Groq API to retrieve more info from the LLM's knowledge base.
98
- Appends it to our aggregated info if found.
99
- """
100
  print("[LOG] Querying LLM for additional information.")
101
  system_prompt = (
102
  "You are an AI assistant with extensive knowledge up to 2023-10. "
@@ -122,10 +102,6 @@ def query_llm_for_additional_info(topic: str, existing_text: str) -> str:
122
  return additional_info
123
 
124
  def research_topic(topic: str) -> str:
125
- """
126
- Gathers info from various RSS feeds and Wikipedia. If needed, queries the LLM
127
- for more data if the aggregated text is insufficient.
128
- """
129
  sources = {
130
  "BBC": "https://feeds.bbci.co.uk/news/rss.xml",
131
  "CNN": "http://rss.cnn.com/rss/edition.rss",
@@ -144,7 +120,6 @@ def research_topic(topic: str) -> str:
144
  if wiki_summary:
145
  summary_parts.append(f"From Wikipedia: {wiki_summary}")
146
 
147
- # For each RSS feed
148
  for name, feed_url in sources.items():
149
  try:
150
  items = fetch_rss_feed(feed_url)
@@ -165,7 +140,6 @@ def research_topic(topic: str) -> str:
165
  print("[DEBUG] Aggregated info from primary sources:")
166
  print(aggregated_info)
167
 
168
- # Fallback to LLM if insufficient
169
  if not is_sufficient(aggregated_info):
170
  print("[LOG] Insufficient info from primary sources. Fallback to LLM.")
171
  additional_info = query_llm_for_additional_info(topic, aggregated_info)
@@ -180,9 +154,6 @@ def research_topic(topic: str) -> str:
180
  return aggregated_info
181
 
182
  def fetch_wikipedia_summary(topic: str) -> str:
183
- """
184
- Fetch a quick Wikipedia summary of the topic via the official Wikipedia API.
185
- """
186
  print("[LOG] Fetching Wikipedia summary for:", topic)
187
  try:
188
  search_url = (
@@ -209,9 +180,6 @@ def fetch_wikipedia_summary(topic: str) -> str:
209
  return ""
210
 
211
  def fetch_rss_feed(feed_url: str) -> list:
212
- """
213
- Pulls RSS feed data from a given URL and returns items.
214
- """
215
  print("[LOG] Fetching RSS feed:", feed_url)
216
  try:
217
  resp = requests.get(feed_url)
@@ -226,10 +194,6 @@ def fetch_rss_feed(feed_url: str) -> list:
226
  return []
227
 
228
  def find_relevant_article(items, topic: str, min_match=2) -> tuple:
229
- """
230
- Check each article in the RSS feed for mention of the topic
231
- by counting the number of keyword matches.
232
- """
233
  print("[LOG] Finding relevant articles...")
234
  keywords = re.findall(r'\w+', topic.lower())
235
  for item in items:
@@ -244,9 +208,6 @@ def find_relevant_article(items, topic: str, min_match=2) -> tuple:
244
  return None, None, None
245
 
246
  def fetch_article_text(link: str) -> str:
247
- """
248
- Fetch the article text from the given link (first 5 paragraphs).
249
- """
250
  print("[LOG] Fetching article text from:", link)
251
  if not link:
252
  print("[LOG] No link provided for article text.")
@@ -275,8 +236,8 @@ def generate_script(
275
  sponsor_style: str = "Separate Break"
276
  ):
277
  """
278
- Sends the system_prompt plus input_text to the Groq LLM to generate a
279
- multi-speaker Dialogue in JSON, returning a Dialogue object.
280
  """
281
  print("[LOG] Generating script with tone:", tone, "and length:", target_length)
282
  groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
@@ -423,7 +384,6 @@ def transcribe_youtube_video(video_url: str) -> str:
423
  print(f"[DEBUG] Transcript Snippet: {snippet}")
424
 
425
  return transcript_as_text
426
-
427
  except Exception as e:
428
  print("[ERROR] RapidAPI transcription error:", e)
429
  raise ValueError(f"Error transcribing YouTube video via RapidAPI: {str(e)}")
@@ -431,8 +391,7 @@ def transcribe_youtube_video(video_url: str) -> str:
431
  def generate_audio_mp3(text: str, speaker: str) -> str:
432
  """
433
  Calls Deepgram TTS with the text, returning a path to a temp MP3 file.
434
- We also do some pre-processing for punctuation, abbreviations, numeric expansions,
435
- plus emotive expressions (ha, sigh, etc.).
436
  """
437
  try:
438
  print(f"[LOG] Generating audio for speaker: {speaker}")
@@ -443,7 +402,7 @@ def generate_audio_mp3(text: str, speaker: str) -> str:
443
  "model": "aura-asteria-en", # female by default
444
  }
445
  if speaker == "John":
446
- params["model"] = "aura-helios-en"
447
 
448
  headers = {
449
  "Accept": "audio/mpeg",
@@ -468,7 +427,6 @@ def generate_audio_mp3(text: str, speaker: str) -> str:
468
  mp3_file.write(chunk)
469
  mp3_path = mp3_file.name
470
 
471
- # Normalize volume
472
  audio_seg = AudioSegment.from_file(mp3_path, format="mp3")
473
  audio_seg = effects.normalize(audio_seg)
474
 
@@ -489,26 +447,25 @@ def transcribe_youtube_video_OLD_YTDLP(video_url: str) -> str:
489
  def _preprocess_text_for_tts(text: str, speaker: str) -> str:
490
  """
491
  1) "SaaS" => "sass"
492
- 2) Insert periods in uppercase abbreviations -> remove for TTS
493
- 3) Convert decimals like "3.14" -> "three point one four"
494
- 4) Convert pure integer numbers like "20" -> "twenty"
495
  5) Expand leftover all-caps
496
  6) Emotive placeholders for 'ha', 'haha', 'sigh', 'groan', etc.
497
- 7) If speaker == "John", we insert short breath pauses only after punctuation
498
- 8) Remove random fillers
499
  9) Capitalize sentence starts
500
  """
501
  # 1) "SaaS" => "sass"
502
  text = re.sub(r"\b(?i)SaaS\b", "sass", text)
503
 
504
- # 2) Insert periods in uppercase abbreviations, then remove them
505
  def insert_periods_for_abbrev(m):
506
  abbr = m.group(0)
507
  parted = ".".join(list(abbr)) + "."
508
  return parted
509
  text = re.sub(r"\b([A-Z0-9]{2,})\b", insert_periods_for_abbrev, text)
510
  text = re.sub(r"\.\.", ".", text)
511
-
512
  def remove_periods_for_tts(m):
513
  # "N.I.A." => "N I A"
514
  chunk = m.group(0)
@@ -527,13 +484,13 @@ def _preprocess_text_for_tts(text: str, speaker: str) -> str:
527
  return f"{whole_part} point {decimal_part}"
528
  text = re.sub(r"\b\d+\.\d+\b", convert_decimal, text)
529
 
530
- # 5) Convert pure integer => words
531
  def convert_int_to_words(m):
532
  num_str = m.group()
533
  return number_to_words(int(num_str))
534
  text = re.sub(r"\b\d+\b", convert_int_to_words, text)
535
 
536
- # 6) Expand leftover all-caps => "NASA" => "N A S A"
537
  def expand_abbreviations(m):
538
  abbrev = m.group()
539
  if abbrev.endswith('s') and abbrev[:-1].isupper():
@@ -549,11 +506,15 @@ def _preprocess_text_for_tts(text: str, speaker: str) -> str:
549
  return " ".join(list(abbrev))
550
  text = re.sub(r"\b[A-Z]{2,}s?\b", expand_abbreviations, text)
551
 
552
- # 7) If speaker == "John", insert short breath pauses after punctuation
 
 
 
 
 
553
  if speaker == "John":
554
- # Insert a short "..." after punctuation marks
555
- text = re.sub(r"([.,!?;:])", r"\1...", text)
556
- # Optionally remove random in-word pausing logic if you had it
557
 
558
  # 8) Remove random fillers
559
  text = re.sub(r"\b(uh|um|ah)\b", "", text, flags=re.IGNORECASE)
@@ -585,67 +546,58 @@ def _spell_digits(d: str) -> str:
585
 
586
  def number_to_words(n: int) -> str:
587
  """
588
- Enhanced integer-to-words up to ~999,999 or more.
589
- E.g., 10 -> 'ten', 4000 -> 'four thousand', 999999 -> 'nine hundred ninety nine thousand nine hundred ninety nine'
590
  """
591
  if n == 0:
592
  return "zero"
593
-
594
  if n < 0:
595
  return "minus " + number_to_words(-n)
596
 
597
- # Up to 999,999 or so. Extend if you need more.
598
  ones = ["","one","two","three","four","five","six","seven","eight","nine"]
599
  teens = ["ten","eleven","twelve","thirteen","fourteen","fifteen","sixteen","seventeen","eighteen","nineteen"]
600
  tens_words = ["","","twenty","thirty","forty","fifty","sixty","seventy","eighty","ninety"]
601
 
602
  def three_digits(x):
603
- """ Convert 0 <= x < 1000 to words """
604
- words = []
605
  hundreds = x // 100
606
- remainder = x % 100
607
-
608
  if hundreds > 0:
609
- words.append(ones[hundreds])
610
- words.append("hundred")
611
- if remainder > 0:
612
- words.append("and")
613
-
614
- if remainder < 10 and remainder > 0:
615
- words.append(ones[remainder])
616
- elif remainder >= 10 and remainder < 20:
617
- words.append(teens[remainder-10])
618
  else:
619
- t = remainder // 10
620
- o = remainder % 10
621
  if t > 1:
622
- words.append(tens_words[t])
623
  if o > 0:
624
- words.append(ones[o])
625
- return " ".join(w for w in words if w)
626
 
627
- # We'll chunk up to 999,999
628
  thousands = n // 1000
629
  remainder = n % 1000
630
 
631
- words_list = []
632
-
633
  if thousands > 0:
634
- words_list.append(three_digits(thousands))
635
- words_list.append("thousand")
636
-
637
  if remainder > 0:
638
- words_list.append(three_digits(remainder))
639
 
640
- final = " ".join(w for w in words_list if w).strip()
641
- return final or "zero"
642
 
643
  def mix_with_bg_music(spoken: AudioSegment, custom_music_path=None) -> AudioSegment:
644
  """
645
- Mixes 'spoken' with a default bg_music.mp3 or user-provided custom music:
646
- 1) Start with 2 seconds of music alone before speech begins.
647
- 2) Loop music if shorter than the final audio length.
648
- 3) Lower music volume so the speech is clear.
649
  """
650
  if custom_music_path:
651
  music_path = custom_music_path
@@ -669,12 +621,10 @@ def mix_with_bg_music(spoken: AudioSegment, custom_music_path=None) -> AudioSegm
669
  final_mix = looped_music.overlay(spoken, position=2000)
670
  return final_mix
671
 
672
- # This function is new for short Q&A calls
673
  def call_groq_api_for_qa(system_prompt: str) -> str:
674
  """
675
- A minimal placeholder for your short Q&A LLM call.
676
- Must return a JSON string, e.g.:
677
- {"speaker": "John", "text": "Short answer here"}
678
  """
679
  groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
680
  try:
 
18
  import random
19
 
20
  class DialogueItem(BaseModel):
21
+ speaker: Literal["Jane", "John"]
22
+ display_speaker: str = "Jane"
23
  text: str
24
 
25
  class Dialogue(BaseModel):
26
  dialogue: List[DialogueItem]
27
 
28
+ # Not used for YouTube, but for local if needed
29
  asr_pipeline = pipeline(
30
  "automatic-speech-recognition",
31
  model="openai/whisper-tiny.en",
 
33
  )
34
 
35
  def truncate_text(text, max_tokens=2048):
 
 
 
 
36
  print("[LOG] Truncating text if needed.")
37
  tokenizer = tiktoken.get_encoding("cl100k_base")
38
  tokens = tokenizer.encode(text)
 
42
  return text
43
 
44
  def extract_text_from_url(url):
 
 
 
 
45
  print("[LOG] Extracting text from URL:", url)
46
  try:
47
  headers = {
 
66
  return ""
67
 
68
  def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
 
 
 
 
69
  print(f"[LOG] Shifting pitch by {semitones} semitones.")
70
  new_sample_rate = int(audio.frame_rate * (2.0 ** (semitones / 12.0)))
71
  shifted_audio = audio._spawn(audio.raw_data, overrides={'frame_rate': new_sample_rate})
72
  return shifted_audio.set_frame_rate(audio.frame_rate)
73
 
74
  def is_sufficient(text: str, min_word_count: int = 500) -> bool:
 
 
 
 
75
  word_count = len(text.split())
76
  print(f"[DEBUG] Aggregated word count: {word_count}")
77
  return word_count >= min_word_count
78
 
79
  def query_llm_for_additional_info(topic: str, existing_text: str) -> str:
 
 
 
 
80
  print("[LOG] Querying LLM for additional information.")
81
  system_prompt = (
82
  "You are an AI assistant with extensive knowledge up to 2023-10. "
 
102
  return additional_info
103
 
104
  def research_topic(topic: str) -> str:
 
 
 
 
105
  sources = {
106
  "BBC": "https://feeds.bbci.co.uk/news/rss.xml",
107
  "CNN": "http://rss.cnn.com/rss/edition.rss",
 
120
  if wiki_summary:
121
  summary_parts.append(f"From Wikipedia: {wiki_summary}")
122
 
 
123
  for name, feed_url in sources.items():
124
  try:
125
  items = fetch_rss_feed(feed_url)
 
140
  print("[DEBUG] Aggregated info from primary sources:")
141
  print(aggregated_info)
142
 
 
143
  if not is_sufficient(aggregated_info):
144
  print("[LOG] Insufficient info from primary sources. Fallback to LLM.")
145
  additional_info = query_llm_for_additional_info(topic, aggregated_info)
 
154
  return aggregated_info
155
 
156
  def fetch_wikipedia_summary(topic: str) -> str:
 
 
 
157
  print("[LOG] Fetching Wikipedia summary for:", topic)
158
  try:
159
  search_url = (
 
180
  return ""
181
 
182
  def fetch_rss_feed(feed_url: str) -> list:
 
 
 
183
  print("[LOG] Fetching RSS feed:", feed_url)
184
  try:
185
  resp = requests.get(feed_url)
 
194
  return []
195
 
196
  def find_relevant_article(items, topic: str, min_match=2) -> tuple:
 
 
 
 
197
  print("[LOG] Finding relevant articles...")
198
  keywords = re.findall(r'\w+', topic.lower())
199
  for item in items:
 
208
  return None, None, None
209
 
210
  def fetch_article_text(link: str) -> str:
 
 
 
211
  print("[LOG] Fetching article text from:", link)
212
  if not link:
213
  print("[LOG] No link provided for article text.")
 
236
  sponsor_style: str = "Separate Break"
237
  ):
238
  """
239
+ If sponsor content is empty, we won't have sponsor instructions appended in app.py's prompt.
240
+ So the LLM should not generate sponsor segments.
241
  """
242
  print("[LOG] Generating script with tone:", tone, "and length:", target_length)
243
  groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
 
384
  print(f"[DEBUG] Transcript Snippet: {snippet}")
385
 
386
  return transcript_as_text
 
387
  except Exception as e:
388
  print("[ERROR] RapidAPI transcription error:", e)
389
  raise ValueError(f"Error transcribing YouTube video via RapidAPI: {str(e)}")
 
391
  def generate_audio_mp3(text: str, speaker: str) -> str:
392
  """
393
  Calls Deepgram TTS with the text, returning a path to a temp MP3 file.
394
+ Then we do normal volume normalization, etc.
 
395
  """
396
  try:
397
  print(f"[LOG] Generating audio for speaker: {speaker}")
 
402
  "model": "aura-asteria-en", # female by default
403
  }
404
  if speaker == "John":
405
+ params["model"] = "aura-zeus-en"
406
 
407
  headers = {
408
  "Accept": "audio/mpeg",
 
427
  mp3_file.write(chunk)
428
  mp3_path = mp3_file.name
429
 
 
430
  audio_seg = AudioSegment.from_file(mp3_path, format="mp3")
431
  audio_seg = effects.normalize(audio_seg)
432
 
 
447
  def _preprocess_text_for_tts(text: str, speaker: str) -> str:
448
  """
449
  1) "SaaS" => "sass"
450
+ 2) Insert periods for uppercase abbreviations -> remove for TTS (N.I.A. => N I A)
451
+ 3) Convert decimals (3.14 => 'three point one four')
452
+ 4) Convert integers (10 => 'ten', 4000 => 'four thousand')
453
  5) Expand leftover all-caps
454
  6) Emotive placeholders for 'ha', 'haha', 'sigh', 'groan', etc.
455
+ 7) If speaker == "John", insert short breath "..." after punctuation (not random mid-word)
456
+ 8) Remove random fillers (uh, um)
457
  9) Capitalize sentence starts
458
  """
459
  # 1) "SaaS" => "sass"
460
  text = re.sub(r"\b(?i)SaaS\b", "sass", text)
461
 
462
+ # 2) Insert periods for uppercase abbreviations => remove them
463
  def insert_periods_for_abbrev(m):
464
  abbr = m.group(0)
465
  parted = ".".join(list(abbr)) + "."
466
  return parted
467
  text = re.sub(r"\b([A-Z0-9]{2,})\b", insert_periods_for_abbrev, text)
468
  text = re.sub(r"\.\.", ".", text)
 
469
  def remove_periods_for_tts(m):
470
  # "N.I.A." => "N I A"
471
  chunk = m.group(0)
 
484
  return f"{whole_part} point {decimal_part}"
485
  text = re.sub(r"\b\d+\.\d+\b", convert_decimal, text)
486
 
487
+ # Convert pure integers => words
488
  def convert_int_to_words(m):
489
  num_str = m.group()
490
  return number_to_words(int(num_str))
491
  text = re.sub(r"\b\d+\b", convert_int_to_words, text)
492
 
493
+ # 5) Expand leftover all-caps => "NASA" => "N A S A"
494
  def expand_abbreviations(m):
495
  abbrev = m.group()
496
  if abbrev.endswith('s') and abbrev[:-1].isupper():
 
506
  return " ".join(list(abbrev))
507
  text = re.sub(r"\b[A-Z]{2,}s?\b", expand_abbreviations, text)
508
 
509
+ # 6) Emotive placeholders
510
+ text = re.sub(r"\b(ha(ha)?|heh|lol)\b", "(* laughs *)", text, flags=re.IGNORECASE)
511
+ text = re.sub(r"\bsigh\b", "(* sighs *)", text, flags=re.IGNORECASE)
512
+ text = re.sub(r"\b(groan|moan)\b", "(* groans *)", text, flags=re.IGNORECASE)
513
+
514
+ # 7) If speaker == "John", place short "..." after punctuation only
515
  if speaker == "John":
516
+ # Insert a short "..." after . , ! ? ; :
517
+ text = re.sub(r"([.,!?;:])(\s|$)", r"\1...\2", text)
 
518
 
519
  # 8) Remove random fillers
520
  text = re.sub(r"\b(uh|um|ah)\b", "", text, flags=re.IGNORECASE)
 
546
 
547
  def number_to_words(n: int) -> str:
548
  """
549
+ Enhanced integer-to-words up to 999,999 so '10' => 'ten', '4000' => 'four thousand'.
 
550
  """
551
  if n == 0:
552
  return "zero"
 
553
  if n < 0:
554
  return "minus " + number_to_words(-n)
555
 
 
556
  ones = ["","one","two","three","four","five","six","seven","eight","nine"]
557
  teens = ["ten","eleven","twelve","thirteen","fourteen","fifteen","sixteen","seventeen","eighteen","nineteen"]
558
  tens_words = ["","","twenty","thirty","forty","fifty","sixty","seventy","eighty","ninety"]
559
 
560
  def three_digits(x):
561
+ w = []
 
562
  hundreds = x // 100
563
+ rem = x % 100
 
564
  if hundreds > 0:
565
+ w.append(ones[hundreds])
566
+ w.append("hundred")
567
+ if rem > 0:
568
+ w.append("and")
569
+ if rem < 10 and rem > 0:
570
+ w.append(ones[rem])
571
+ elif rem >= 10 and rem < 20:
572
+ w.append(teens[rem - 10])
 
573
  else:
574
+ t = rem // 10
575
+ o = rem % 10
576
  if t > 1:
577
+ w.append(tens_words[t])
578
  if o > 0:
579
+ w.append(ones[o])
580
+ return " ".join(i for i in w if i)
581
 
 
582
  thousands = n // 1000
583
  remainder = n % 1000
584
 
585
+ parts = []
 
586
  if thousands > 0:
587
+ parts.append(three_digits(thousands))
588
+ parts.append("thousand")
 
589
  if remainder > 0:
590
+ parts.append(three_digits(remainder))
591
 
592
+ out = " ".join(i for i in parts if i).strip()
593
+ return out or "zero"
594
 
595
  def mix_with_bg_music(spoken: AudioSegment, custom_music_path=None) -> AudioSegment:
596
  """
597
+ Mixes 'spoken' with bg_music.mp3 or custom music:
598
+ - 2s lead-in
599
+ - Loop if shorter
600
+ - Lower volume
601
  """
602
  if custom_music_path:
603
  music_path = custom_music_path
 
621
  final_mix = looped_music.overlay(spoken, position=2000)
622
  return final_mix
623
 
 
624
  def call_groq_api_for_qa(system_prompt: str) -> str:
625
  """
626
+ Minimal function for short Q&A calls. Must return JSON:
627
+ { "speaker": "John", "text": "Short answer" }
 
628
  """
629
  groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
630
  try: