siddhartharyaai commited on
Commit
e62d0b2
·
verified ·
1 Parent(s): e7283ef

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +102 -136
utils.py CHANGED
@@ -12,9 +12,9 @@ from pydub import AudioSegment, effects
12
  from transformers import pipeline
13
  import yt_dlp
14
  import tiktoken
15
- from groq import Groq # Ensure Groq client is imported
16
  import numpy as np
17
- import torch # Added to check CUDA availability
18
  import random
19
 
20
  class DialogueItem(BaseModel):
@@ -56,7 +56,7 @@ def extract_text_from_url(url):
56
  def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
57
  """
58
  Shifts the pitch of an AudioSegment by a given number of semitones.
59
- Positive semitones shift the pitch up, negative shift it down.
60
  """
61
  print(f"[LOG] Shifting pitch by {semitones} semitones.")
62
  new_sample_rate = int(audio.frame_rate * (2.0 ** (semitones / 12.0)))
@@ -83,7 +83,6 @@ def query_llm_for_additional_info(topic: str, existing_text: str) -> str:
83
  f"Existing Information: {existing_text}\n\n"
84
  "Please add more insightful details, facts, and perspectives to enhance the understanding of the topic."
85
  )
86
-
87
  groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
88
  try:
89
  response = groq_client.chat.completions.create(
@@ -95,14 +94,13 @@ def query_llm_for_additional_info(topic: str, existing_text: str) -> str:
95
  except Exception as e:
96
  print("[ERROR] Groq API error during fallback:", e)
97
  return ""
98
-
99
  additional_info = response.choices[0].message.content.strip()
100
  print("[DEBUG] Additional information from LLM:")
101
  print(additional_info)
102
  return additional_info
103
 
104
  def research_topic(topic: str) -> str:
105
- # Sources:
106
  sources = {
107
  "BBC": "https://feeds.bbci.co.uk/news/rss.xml",
108
  "CNN": "http://rss.cnn.com/rss/edition.rss",
@@ -116,6 +114,7 @@ def research_topic(topic: str) -> str:
116
 
117
  summary_parts = []
118
 
 
119
  wiki_summary = fetch_wikipedia_summary(topic)
120
  if wiki_summary:
121
  summary_parts.append(f"From Wikipedia: {wiki_summary}")
@@ -137,7 +136,7 @@ def research_topic(topic: str) -> str:
137
  continue
138
 
139
  aggregated_info = " ".join(summary_parts)
140
- print("[DEBUG] Aggregated information from primary sources.")
141
  print(aggregated_info)
142
 
143
  if not is_sufficient(aggregated_info):
@@ -159,7 +158,7 @@ def fetch_wikipedia_summary(topic: str) -> str:
159
  search_url = f"https://en.wikipedia.org/w/api.php?action=opensearch&search={requests.utils.quote(topic)}&limit=1&namespace=0&format=json"
160
  resp = requests.get(search_url)
161
  if resp.status_code != 200:
162
- print(f"[ERROR] Failed to fetch Wikipedia search for topic: {topic}")
163
  return ""
164
  data = resp.json()
165
  if len(data) > 1 and data[1]:
@@ -169,7 +168,7 @@ def fetch_wikipedia_summary(topic: str) -> str:
169
  if s_resp.status_code == 200:
170
  s_data = s_resp.json()
171
  if "extract" in s_data:
172
- print("[LOG] Wikipedia summary fetched.")
173
  return s_data["extract"]
174
  return ""
175
  except Exception as e:
@@ -181,17 +180,19 @@ def fetch_rss_feed(feed_url: str) -> list:
181
  try:
182
  resp = requests.get(feed_url)
183
  if resp.status_code != 200:
184
- print(f"[ERROR] Failed to fetch RSS feed {feed_url}")
185
  return []
186
  soup = BeautifulSoup(resp.content, "html.parser")
187
  items = soup.find_all("item")
188
- print(f"[LOG] Number of items: {len(items)} from {feed_url}")
189
  return items
190
  except Exception as e:
191
  print(f"[ERROR] Exception fetching RSS feed {feed_url}: {e}")
192
  return []
193
 
194
  def find_relevant_article(items, topic: str, min_match=2) -> tuple:
 
 
 
195
  print("[LOG] Finding relevant articles...")
196
  keywords = re.findall(r'\w+', topic.lower())
197
  for item in items:
@@ -201,12 +202,12 @@ def find_relevant_article(items, topic: str, min_match=2) -> tuple:
201
  matches = sum(1 for kw in keywords if kw in text)
202
  if matches >= min_match:
203
  link = item.find("link").get_text().strip() if item.find("link") else ""
204
- print(f"[LOG] Relevant article: {title}")
205
  return title, description, link
206
  return None, None, None
207
 
208
  def fetch_article_text(link: str) -> str:
209
- print("[LOG] Fetching article text:", link)
210
  if not link:
211
  return ""
212
  try:
@@ -262,9 +263,6 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
262
  "}"
263
  )
264
 
265
- print("[LOG] Sending prompt to Groq:")
266
- print(prompt)
267
-
268
  try:
269
  response = groq_client.chat.completions.create(
270
  messages=[{"role": "system", "content": prompt}],
@@ -285,152 +283,120 @@ def generate_script(system_prompt: str, input_text: str, tone: str, target_lengt
285
  data = json.loads(json_str)
286
  return Dialogue(**data)
287
 
288
- # -------------------------------------------------------------
289
- # Helper function: Insert random filler words, extra punctuation
290
- # BUT we'll handle that chunk by chunk (see below).
291
- # -------------------------------------------------------------
292
- def _make_text_sound_more_human(text: str) -> str:
293
  """
294
- Inserts small filler words and modifies punctuation
295
- for more natural-sounding speech.
 
296
  """
297
- fillers = ["uh", "um", "ah", "hmm", "you know", "well", "I mean", "like"]
298
- # Insert filler sometimes at start or middle:
299
- if text and random.random() < 0.4:
300
- filler = random.choice(fillers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  if random.random() < 0.5:
302
- text = f"{filler}, {text}"
303
  else:
304
- words = text.split()
305
- mid = len(words) // 2
306
- text = " ".join(words[:mid] + [f"{filler},"] + words[mid:])
307
-
308
- # Possibly turn periods into "..." to force a pause
309
- text = re.sub(r'\.(\s|$)', lambda m: "..." + m.group(1), text)
310
-
311
- # Possibly turn "?" into "?!" or "!!" for exclamation
312
- if random.random() < 0.2:
313
- text = text.replace("?", "?!")
314
- if random.random() < 0.2:
315
- text = text.replace("!", "!!")
316
 
317
  return text.strip()
318
 
319
- def _split_into_sentences_and_phrases(text: str):
320
  """
321
- Splits the text into smaller chunks so each chunk can be TTS-ed
322
- individually for better pacing. We'll look for ., !, or ?
323
- as sentence boundaries. Also splits by commas for short phrases.
324
  """
325
- # Split by sentence enders with a lookbehind to keep delimiters separate.
326
- # We can then further split by commas if the sentence is long.
327
- # E.g. "Hello there. This is a test?" => ["Hello there.", "This is a test?"]
328
- # Then if "Hello there." is too big, we might split by commas as well.
329
- boundaries = re.split(r'([.?!])', text)
330
-
331
- # Rebuild into "sentence + punctuation" pairs
332
- phrases = []
333
- for i in range(0, len(boundaries), 2):
334
- if i + 1 < len(boundaries):
335
- chunk = (boundaries[i] + boundaries[i+1]).strip()
336
- else:
337
- chunk = boundaries[i].strip()
338
- if chunk:
339
- # Now optionally split chunk by commas if it's too big
340
- subparts = chunk.split(',')
341
- # If there's more than 1 subpart, rejoin them carefully so each subpart can be TTS-ed on its own
342
- for idx, sp in enumerate(subparts):
343
- part = sp.strip()
344
- if part:
345
- # Re-add comma except on the last one
346
- if idx < len(subparts) - 1:
347
- part += ","
348
- phrases.append(part)
349
- return phrases
350
 
351
  def generate_audio_mp3(text: str, speaker: str) -> str:
 
 
 
352
  try:
353
  print(f"[LOG] Generating audio for speaker: {speaker}")
354
 
355
- # Step 1: Split text into small pieces (phrases, sentences)
356
- fragments = _split_into_sentences_and_phrases(text)
357
-
358
- # Step 2: For each fragment, transform it to be more human-like, TTS it, then combine
359
- all_segments = []
360
- for frag in fragments:
361
- if not frag.strip():
362
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
- # Make the chunk more "human"
365
- human_chunk = _make_text_sound_more_human(frag)
 
366
 
367
- # TTS this chunk
368
- mp3_path = _tts_chunk(human_chunk, speaker)
369
- seg = AudioSegment.from_file(mp3_path, format="mp3")
370
- seg = effects.normalize(seg)
371
- all_segments.append(seg)
372
 
373
- # Clean up
374
- if os.path.exists(mp3_path):
375
- os.remove(mp3_path)
376
 
377
- if not all_segments:
378
- raise ValueError("No audio segments produced.")
379
 
380
- # Step 3: Combine segments with a short silence between
381
- final_audio = all_segments[0]
382
- short_silence = AudioSegment.silent(duration=300) # 300ms silence
383
- for seg in all_segments[1:]:
384
- final_audio = final_audio + short_silence + seg
385
 
386
- # Step 4: Save combined
387
- final_mp3_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
388
- final_audio.export(final_mp3_path, format="mp3")
389
- print("[LOG] Combined audio saved at:", final_mp3_path)
390
  return final_mp3_path
391
-
392
  except Exception as e:
393
  print("[ERROR] Error generating audio:", e)
394
  raise ValueError(f"Error generating audio: {str(e)}")
395
 
396
- def _tts_chunk(text: str, speaker: str) -> str:
397
  """
398
- Helper function to do TTS on a single chunk of text
399
- (so we can call multiple times).
400
  """
401
- deepgram_api_url = "https://api.deepgram.com/v1/speak"
402
- params = {
403
- "model": "aura-asteria-en", # default female
404
- }
405
- if speaker == "John":
406
- params["model"] = "aura-perseus-en"
407
-
408
- headers = {
409
- "Accept": "audio/mpeg",
410
- "Content-Type": "application/json",
411
- "Authorization": f"Token {os.environ.get('DEEPGRAM_API_KEY')}"
412
- }
413
- body = {
414
- "text": text
415
- }
416
-
417
- response = requests.post(deepgram_api_url, params=params, headers=headers, json=body, stream=True)
418
- if response.status_code != 200:
419
- raise ValueError(f"Deepgram TTS error: {response.status_code}, {response.text}")
420
-
421
- content_type = response.headers.get('Content-Type', '')
422
- if 'audio/mpeg' not in content_type:
423
- raise ValueError("Unexpected Content-Type from Deepgram.")
424
-
425
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_file:
426
- for chunk in response.iter_content(chunk_size=8192):
427
- if chunk:
428
- mp3_file.write(chunk)
429
- mp3_path = mp3_file.name
430
-
431
- return mp3_path
432
-
433
- def transcribe_youtube_video(video_url: str) -> str:
434
  print("[LOG] Transcribing YouTube video:", video_url)
435
  fd, audio_file = tempfile.mkstemp(suffix=".wav")
436
  os.close(fd)
@@ -464,4 +430,4 @@ def transcribe_youtube_video(video_url: str) -> str:
464
  finally:
465
  if os.path.exists(audio_file):
466
  os.remove(audio_file)
467
- print(f"[LOG] Removed temp audio file: {audio_file}")
 
12
  from transformers import pipeline
13
  import yt_dlp
14
  import tiktoken
15
+ from groq import Groq
16
  import numpy as np
17
+ import torch
18
  import random
19
 
20
  class DialogueItem(BaseModel):
 
56
  def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
57
  """
58
  Shifts the pitch of an AudioSegment by a given number of semitones.
59
+ Positive semitones shift the pitch up, negative shifts it down.
60
  """
61
  print(f"[LOG] Shifting pitch by {semitones} semitones.")
62
  new_sample_rate = int(audio.frame_rate * (2.0 ** (semitones / 12.0)))
 
83
  f"Existing Information: {existing_text}\n\n"
84
  "Please add more insightful details, facts, and perspectives to enhance the understanding of the topic."
85
  )
 
86
  groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
87
  try:
88
  response = groq_client.chat.completions.create(
 
94
  except Exception as e:
95
  print("[ERROR] Groq API error during fallback:", e)
96
  return ""
 
97
  additional_info = response.choices[0].message.content.strip()
98
  print("[DEBUG] Additional information from LLM:")
99
  print(additional_info)
100
  return additional_info
101
 
102
  def research_topic(topic: str) -> str:
103
+ # News sources
104
  sources = {
105
  "BBC": "https://feeds.bbci.co.uk/news/rss.xml",
106
  "CNN": "http://rss.cnn.com/rss/edition.rss",
 
114
 
115
  summary_parts = []
116
 
117
+ # Wikipedia summary
118
  wiki_summary = fetch_wikipedia_summary(topic)
119
  if wiki_summary:
120
  summary_parts.append(f"From Wikipedia: {wiki_summary}")
 
136
  continue
137
 
138
  aggregated_info = " ".join(summary_parts)
139
+ print("[DEBUG] Aggregated info from primary sources:")
140
  print(aggregated_info)
141
 
142
  if not is_sufficient(aggregated_info):
 
158
  search_url = f"https://en.wikipedia.org/w/api.php?action=opensearch&search={requests.utils.quote(topic)}&limit=1&namespace=0&format=json"
159
  resp = requests.get(search_url)
160
  if resp.status_code != 200:
161
+ print(f"[ERROR] Failed to fetch Wikipedia search results for topic: {topic}")
162
  return ""
163
  data = resp.json()
164
  if len(data) > 1 and data[1]:
 
168
  if s_resp.status_code == 200:
169
  s_data = s_resp.json()
170
  if "extract" in s_data:
171
+ print("[LOG] Wikipedia summary fetched successfully.")
172
  return s_data["extract"]
173
  return ""
174
  except Exception as e:
 
180
  try:
181
  resp = requests.get(feed_url)
182
  if resp.status_code != 200:
183
+ print(f"[ERROR] Failed to fetch RSS feed: {feed_url}")
184
  return []
185
  soup = BeautifulSoup(resp.content, "html.parser")
186
  items = soup.find_all("item")
 
187
  return items
188
  except Exception as e:
189
  print(f"[ERROR] Exception fetching RSS feed {feed_url}: {e}")
190
  return []
191
 
192
  def find_relevant_article(items, topic: str, min_match=2) -> tuple:
193
+ """
194
+ Searches for relevant articles based on topic keywords.
195
+ """
196
  print("[LOG] Finding relevant articles...")
197
  keywords = re.findall(r'\w+', topic.lower())
198
  for item in items:
 
202
  matches = sum(1 for kw in keywords if kw in text)
203
  if matches >= min_match:
204
  link = item.find("link").get_text().strip() if item.find("link") else ""
205
+ print(f"[LOG] Relevant article found: {title}")
206
  return title, description, link
207
  return None, None, None
208
 
209
  def fetch_article_text(link: str) -> str:
210
+ print("[LOG] Fetching article text from:", link)
211
  if not link:
212
  return ""
213
  try:
 
263
  "}"
264
  )
265
 
 
 
 
266
  try:
267
  response = groq_client.chat.completions.create(
268
  messages=[{"role": "system", "content": prompt}],
 
283
  data = json.loads(json_str)
284
  return Dialogue(**data)
285
 
286
+ # --------------------------------------------------------------
287
+ # TTS Preprocessing to handle decimals, hyphens, and selective fillers
288
+ # --------------------------------------------------------------
289
+ def _preprocess_text_for_tts(text: str) -> str:
 
290
  """
291
+ 1) Convert decimals to spelled-out words ("3.14" -> "three point one four").
292
+ 2) Replace hyphens with spaces.
293
+ 3) Insert filler words only in certain contexts (like "I think", or after '?').
294
  """
295
+ # 1) Convert decimals
296
+ def convert_decimal(m):
297
+ number_str = m.group() # e.g. "3.14"
298
+ parts = number_str.split('.')
299
+ whole_part = _spell_digits(parts[0]) # "three"
300
+ decimal_part = " ".join(_spell_digits(d) for d in parts[1])
301
+ return f"{whole_part} point {decimal_part}"
302
+
303
+ text = re.sub(r"\d+\.\d+", convert_decimal, text)
304
+
305
+ # 2) Hyphens -> spaces
306
+ text = re.sub(r"-", " ", text)
307
+
308
+ # 3) Targeted filler insertion
309
+ # a) Insert "uh" after "I think" or "I'm not sure", etc. (very naive approach)
310
+ text = re.sub(
311
+ r"(I think|I'm not sure|I guess)([,.]?\s)",
312
+ r"\1, uh,\2",
313
+ text,
314
+ flags=re.IGNORECASE
315
+ )
316
+
317
+ # b) If there's a "?" then sometimes insert "um," right after it
318
+ text = text.replace("?", "?<QMARK>")
319
+ def insert_filler_qmark(m):
320
  if random.random() < 0.5:
321
+ return "? um,"
322
  else:
323
+ return "?"
324
+ text = re.sub(r"\?<QMARK>", insert_filler_qmark, text)
 
 
 
 
 
 
 
 
 
 
325
 
326
  return text.strip()
327
 
328
+ def _spell_digits(d: str) -> str:
329
  """
330
+ Convert each digit '3' -> 'three', '5' -> 'five', etc.
 
 
331
  """
332
+ digit_map = {
333
+ '0': 'zero', '1': 'one', '2': 'two', '3': 'three',
334
+ '4': 'four','5': 'five','6': 'six','7': 'seven',
335
+ '8': 'eight','9': 'nine'
336
+ }
337
+ return " ".join(digit_map[ch] for ch in d if ch in digit_map)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
def generate_audio_mp3(text: str, speaker: str) -> str:
    """
    Main TTS function: synthesize *text* via the Deepgram speak API.

    Args:
        text: Text to speak; first run through _preprocess_text_for_tts
            (decimal spelling, hyphen removal, filler insertion).
        speaker: Speaker name; "John" selects the "aura-perseus-en" voice,
            anything else uses the default "aura-asteria-en" voice.

    Returns:
        Path to a temporary MP3 file with volume-normalized audio.
        The caller is responsible for deleting this file.

    Raises:
        ValueError: On any Deepgram API or audio-processing failure.
    """
    try:
        print(f"[LOG] Generating audio for speaker: {speaker}")

        # Preprocess text (decimal/hyphen/fillers)
        processed_text = _preprocess_text_for_tts(text)

        # Define Deepgram API endpoint and voice selection.
        deepgram_api_url = "https://api.deepgram.com/v1/speak"
        params = {
            "model": "aura-asteria-en",  # default female
        }
        if speaker == "John":
            params["model"] = "aura-perseus-en"

        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "Authorization": f"Token {os.environ.get('DEEPGRAM_API_KEY')}"
        }
        body = {
            "text": processed_text
        }

        print("[LOG] Sending TTS request to Deepgram...")
        # Context manager ensures the streamed connection is released even
        # when an error is raised before the body is fully consumed
        # (stream=True responses must be closed or drained explicitly).
        with requests.post(deepgram_api_url, params=params, headers=headers,
                           json=body, stream=True) as response:
            if response.status_code != 200:
                raise ValueError(f"Deepgram TTS error: {response.status_code}, {response.text}")

            content_type = response.headers.get('Content-Type', '')
            if 'audio/mpeg' not in content_type:
                raise ValueError("Unexpected Content-Type from Deepgram.")

            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        mp3_file.write(chunk)
                mp3_path = mp3_file.name

        # Normalize volume
        audio_seg = AudioSegment.from_file(mp3_path, format="mp3")
        audio_seg = effects.normalize(audio_seg)

        # mkstemp + immediate close avoids the dangling open handle left
        # by NamedTemporaryFile(delete=False, ...).name.
        fd, final_mp3_path = tempfile.mkstemp(suffix=".mp3")
        os.close(fd)
        audio_seg.export(final_mp3_path, format="mp3")

        # Remove the intermediate (un-normalized) download.
        if os.path.exists(mp3_path):
            os.remove(mp3_path)

        return final_mp3_path
    except Exception as e:
        print("[ERROR] Error generating audio:", e)
        raise ValueError(f"Error generating audio: {str(e)}")
395
 
396
+ def transcribe_youtube_video(video_url: str) -> str:
397
  """
398
+ Downloads and transcribes the audio from a YouTube video using Whisper (pipeline).
 
399
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  print("[LOG] Transcribing YouTube video:", video_url)
401
  fd, audio_file = tempfile.mkstemp(suffix=".wav")
402
  os.close(fd)
 
430
  finally:
431
  if os.path.exists(audio_file):
432
  os.remove(audio_file)
433
+ print(f"[LOG] Removed temporary audio file: {audio_file}")