MyPod_10 / utils.py
siddhartharyaai's picture
Update utils.py
84a3c5a verified
raw
history blame
24.4 kB
# utils.py
import os
import re
import json
import requests
import tempfile
from bs4 import BeautifulSoup
from typing import List, Literal
from pydantic import BaseModel
from pydub import AudioSegment, effects
from transformers import pipeline
import yt_dlp
import tiktoken
from groq import Groq
import numpy as np
import torch
import random
# ---------------------------------------------------------------------
# Updated: DialogueItem now has an extra field `display_speaker`
# ---------------------------------------------------------------------
class DialogueItem(BaseModel):
    """A single line of the generated script, with its voice and display name."""

    # Only "Jane" or "John" are accepted; generate_audio_mp3 picks the TTS
    # voice from this value.
    speaker: Literal["Jane", "John"] # Used internally for TTS voice
    # Free-form name; defaults to "Jane" when the caller does not supply one.
    display_speaker: str = "Jane" # The name shown in the user-facing transcript
    # The words spoken in this line.
    text: str
class Dialogue(BaseModel):
    """A complete script: the list of DialogueItem lines returned by generate_script."""

    # Script lines, kept in the order they were appended.
    dialogue: List[DialogueItem]
# Initialize Whisper ASR pipeline (unused for YouTube now, but still available for local audio).
# NOTE: this runs at import time, so importing this module loads the model.
# device=0 selects the first CUDA GPU when one is available; -1 selects the CPU.
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    device=0 if torch.cuda.is_available() else -1
)
def truncate_text(text, max_tokens=2048):
    """
    Cap `text` at `max_tokens` tokens of the "cl100k_base" encoding.

    Text that already fits is returned unchanged. Longer text is cut down to
    its first `max_tokens` tokens and decoded back into a string, so the
    result stays inside the model's context window.
    """
    print("[LOG] Truncating text if needed.")
    encoding = tiktoken.get_encoding("cl100k_base")
    token_ids = encoding.encode(text)

    # Guard clause: nothing to do when the text is already short enough.
    if len(token_ids) <= max_tokens:
        return text

    print("[LOG] Text too long, truncating.")
    kept_ids = token_ids[:max_tokens]
    return encoding.decode(kept_ids)
def extract_text_from_url(url):
    """
    Fetch a web page and return its text content.

    <script> and <style> elements are removed before the text is extracted,
    and text nodes are joined with single spaces.

    Args:
        url: Address of the page to download.

    Returns:
        The extracted text, or "" when the request does not return HTTP 200
        or any exception (network error, timeout, parse failure) occurs.
        Failures are logged, never raised.
    """
    print("[LOG] Extracting text from URL:", url)
    # Browser-like User-Agent sent with the request.
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/115.0.0.0 Safari/537.36"
        )
    }
    try:
        # requests waits indefinitely when no timeout is given; bound the wait
        # so a stalled server cannot hang the caller (same 30 s value used by
        # transcribe_youtube_video).
        response = requests.get(url, headers=headers, timeout=30)
        if response.status_code != 200:
            print(f"[ERROR] Failed to fetch URL: {url} with status code {response.status_code}")
            return ""
        soup = BeautifulSoup(response.text, 'html.parser')
        # Drop non-readable elements before collecting text.
        for element in soup(["script", "style"]):
            element.decompose()
        text = soup.get_text(separator=' ')
        print("[LOG] Text extraction from URL successful.")
        return text
    except Exception as e:
        # Best-effort helper: report the problem and hand back an empty result.
        print(f"[ERROR] Exception during text extraction from URL: {e}")
        return ""
def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
    """
    Return a copy of `audio` with its pitch moved by `semitones`.

    Positive values raise the pitch, negative values lower it. The raw
    samples are re-labelled with a frame rate scaled by 2 ** (semitones / 12)
    and the result is then converted back to the original frame rate.
    """
    print(f"[LOG] Shifting pitch by {semitones} semitones.")
    original_rate = audio.frame_rate
    rate_ratio = 2.0 ** (semitones / 12.0)
    retagged = audio._spawn(
        audio.raw_data,
        overrides={'frame_rate': int(original_rate * rate_ratio)},
    )
    return retagged.set_frame_rate(original_rate)
def is_sufficient(text: str, min_word_count: int = 500) -> bool:
    """
    Tell whether `text` contains at least `min_word_count` words.

    Words are whitespace-separated tokens. The measured count is printed as a
    debug line before the verdict is returned.
    """
    words = text.split()
    word_count = len(words)
    print(f"[DEBUG] Aggregated word count: {word_count}")
    return min_word_count <= word_count
def query_llm_for_additional_info(topic: str, existing_text: str) -> str:
    """
    Ask the Groq-hosted LLM for extra background on `topic`.

    The prompt contains the topic and the text gathered so far and asks the
    model to add further details.

    Args:
        topic: Subject being researched.
        existing_text: Information already collected; embedded in the prompt.

    Returns:
        The model's answer with surrounding whitespace removed, or "" when
        the client cannot be created, the API call fails, or the response
        carries no text. This function is a best-effort fallback and never
        raises for those conditions.
    """
    print("[LOG] Querying LLM for additional information.")
    system_prompt = (
        "You are an AI assistant with extensive knowledge up to 2023-10. "
        "Provide additional relevant information on the following topic based on your knowledge base.\n\n"
        f"Topic: {topic}\n\n"
        f"Existing Information: {existing_text}\n\n"
        "Please add more insightful details, facts, and perspectives to enhance the understanding of the topic."
    )
    try:
        # Client construction lives inside the try as well: it raises when no
        # API key is configured, and that must not crash the fallback path.
        groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        response = groq_client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt}],
            model="llama-3.3-70b-versatile",
            max_tokens=1024,
            temperature=0.7
        )
        # Indexing is guarded too, in case the response has no choices.
        content = response.choices[0].message.content
    except Exception as e:
        print("[ERROR] Groq API error during fallback:", e)
        return ""
    # The message content may be None; treat that as "nothing to add".
    additional_info = (content or "").strip()
    print("[DEBUG] Additional information from LLM:")
    print(additional_info)
    return additional_info
def research_topic(topic: str) -> str:
    """
    Collect background material on `topic`.

    A Wikipedia summary is requested first, then each configured RSS feed is
    scanned for an article matching the topic. When the combined text fails
    the is_sufficient check, the LLM is asked for additional information.
    If nothing at all could be gathered, an apology sentence naming the
    topic is returned instead.
    """
    sources = {
        "BBC": "https://feeds.bbci.co.uk/news/rss.xml",
        "CNN": "http://rss.cnn.com/rss/edition.rss",
        "Associated Press": "https://apnews.com/apf-topnews",
        "NDTV": "https://www.ndtv.com/rss/top-stories",
        "Times of India": "https://timesofindia.indiatimes.com/rssfeeds/296589292.cms",
        "The Hindu": "https://www.thehindu.com/news/national/kerala/rssfeed.xml",
        "Economic Times": "https://economictimes.indiatimes.com/rssfeeds/1977021501.cms",
        "Google News - Custom": f"https://news.google.com/rss/search?q={requests.utils.quote(topic)}&hl=en-IN&gl=IN&ceid=IN:en",
    }
    summary_parts = []

    # Wikipedia comes first so it leads the aggregated text.
    wiki_summary = fetch_wikipedia_summary(topic)
    if wiki_summary:
        summary_parts.append(f"From Wikipedia: {wiki_summary}")

    # One matching article per feed; a failing feed never stops the others.
    for name, url in sources.items():
        try:
            items = fetch_rss_feed(url)
            if not items:
                continue
            title, desc, link = find_relevant_article(items, topic, min_match=2)
            if not link:
                continue
            article_text = fetch_article_text(link)
            # Use the feed's own title/description when the article body
            # could not be retrieved.
            snippet = article_text if article_text else f"{title} - {desc}"
            summary_parts.append(f"From {name}: {snippet}")
        except Exception as e:
            print(f"[ERROR] Error fetching from {name} RSS feed:", e)
            continue

    aggregated_info = " ".join(summary_parts)
    print("[DEBUG] Aggregated info from primary sources:")
    print(aggregated_info)

    # Too little material: top it up from the LLM.
    if not is_sufficient(aggregated_info):
        print("[LOG] Insufficient info from primary sources. Fallback to LLM.")
        additional_info = query_llm_for_additional_info(topic, aggregated_info)
        if additional_info:
            aggregated_info += " " + additional_info
        else:
            print("[ERROR] Failed to retrieve additional info from LLM.")

    return aggregated_info or f"Sorry, I couldn't find recent information on '{topic}'."
def fetch_wikipedia_summary(topic: str) -> str:
    """
    Fetch a short Wikipedia summary for `topic`.

    The opensearch API is queried for the best-matching article title, then
    the REST "page/summary" endpoint is queried for that title and its
    "extract" field is returned.

    Args:
        topic: Free-text search phrase.

    Returns:
        The summary extract, or "" when no article is found, a request does
        not return HTTP 200, or any exception occurs (failures are logged,
        never raised).
    """
    print("[LOG] Fetching Wikipedia summary for:", topic)
    try:
        search_url = (
            f"https://en.wikipedia.org/w/api.php?action=opensearch&search={requests.utils.quote(topic)}"
            "&limit=1&namespace=0&format=json"
        )
        # Bounded wait: requests blocks indefinitely without a timeout.
        resp = requests.get(search_url, timeout=30)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch Wikipedia search results for {topic}")
            return ""
        data = resp.json()
        # The code reads the matching titles from index 1 of the response.
        if len(data) > 1 and data[1]:
            title = data[1][0]
            # safe="" also escapes "/" so that a title such as "AC/DC" stays a
            # single path segment instead of being split into two.
            encoded_title = requests.utils.quote(title, safe="")
            summary_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{encoded_title}"
            s_resp = requests.get(summary_url, timeout=30)
            if s_resp.status_code == 200:
                s_data = s_resp.json()
                if "extract" in s_data:
                    print("[LOG] Wikipedia summary fetched successfully.")
                    return s_data["extract"]
        return ""
    except Exception as e:
        print(f"[ERROR] Exception during Wikipedia summary fetch: {e}")
        return ""
def fetch_rss_feed(feed_url: str) -> list:
    """
    Download an RSS feed and return its <item> elements.

    Args:
        feed_url: Address of the RSS document.

    Returns:
        The parsed <item> elements found in the feed, or an empty list when
        the request does not return HTTP 200 or any exception occurs
        (failures are logged, never raised).
    """
    print("[LOG] Fetching RSS feed:", feed_url)
    try:
        # Bounded wait so one unresponsive feed cannot stall the whole
        # research loop; requests has no default timeout.
        resp = requests.get(feed_url, timeout=30)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch RSS feed: {feed_url}")
            return []
        # Parse the raw bytes with the XML parser.
        soup = BeautifulSoup(resp.content, "xml")
        return soup.find_all("item")
    except Exception as e:
        print(f"[ERROR] Exception fetching RSS feed {feed_url}: {e}")
        return []
def find_relevant_article(items, topic: str, min_match=2) -> tuple:
    """
    Return the first feed item whose title/description mentions the topic.

    The topic is lower-cased and split into distinct keywords; an item
    matches when at least `min_match` of them occur (as substrings) in its
    lower-cased title plus description.

    The required count is capped at the number of distinct keywords, so a
    one-word topic such as "bitcoin" can match with the default
    `min_match=2`. Repeated words in the topic count once.

    Args:
        items: Iterable of parsed feed entries exposing `find(tag)`, whose
            result offers `get_text()`.
        topic: Free-text topic to look for.
        min_match: Number of keywords that must be present.

    Returns:
        `(title, description, link)` of the first matching item (missing
        fields become ""), or `(None, None, None)` when nothing matches.
    """
    print("[LOG] Finding relevant articles...")

    def field_text(item, tag_name):
        # Look the tag up once; a missing tag yields an empty string.
        node = item.find(tag_name)
        return node.get_text().strip() if node is not None else ""

    keywords = set(re.findall(r'\w+', topic.lower()))
    # Never demand more hits than there are distinct keywords. With no
    # keywords at all the caller's threshold is kept unchanged.
    required = min(min_match, len(keywords)) if keywords else min_match

    for item in items:
        title = field_text(item, "title")
        description = field_text(item, "description")
        haystack = (title + " " + description).lower()
        matches = sum(1 for kw in keywords if kw in haystack)
        if matches >= required:
            link = field_text(item, "link")
            print(f"[LOG] Relevant article found: {title}")
            return title, description, link
    return None, None, None
def fetch_article_text(link: str) -> str:
    """
    Fetch an article page and return the text of its first five <p> elements.

    Args:
        link: Address of the article; an empty value short-circuits to "".

    Returns:
        The paragraph texts joined by single spaces and stripped, or "" when
        no link is given, the request does not return HTTP 200, or any
        exception occurs (failures are logged, never raised).
    """
    print("[LOG] Fetching article text from:", link)
    if not link:
        print("[LOG] No link provided for article text.")
        return ""
    try:
        # requests has no default timeout; bound the wait so a slow publisher
        # cannot block the caller forever.
        resp = requests.get(link, timeout=30)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch article from {link}")
            return ""
        soup = BeautifulSoup(resp.text, 'html.parser')
        paragraphs = soup.find_all("p")
        text = " ".join(p.get_text() for p in paragraphs[:5])  # first 5 paragraphs
        print("[LOG] Article text fetched successfully.")
        return text.strip()
    except Exception as e:
        print(f"[ERROR] Error fetching article text: {e}")
        return ""
# ---------------------------------------------------------------------
# Pass host_name & guest_name so we can do "female voice" vs "male voice"
# and display_speaker vs. speaker
# ---------------------------------------------------------------------
def generate_script(system_prompt: str, input_text: str, tone: str, target_length: str,
                    host_name: str = "Jane", guest_name: str = "John"):
    """
    Ask the Groq LLM for a two-speaker script and return it as a Dialogue.

    The prompt combines `system_prompt`, a tone description, a word budget
    derived from `target_length` and the `input_text`, and requests a JSON
    object of the form {"dialogue": [{"speaker": ..., "text": ...}, ...]}.
    The JSON example in the prompt uses `host_name` and `guest_name`, so the
    names the model is shown are the same names the post-processing matches.

    Speaker mapping (case-insensitive, surrounding whitespace ignored):
      - name equals `host_name`  -> speaker="Jane" (female voice),
        display_speaker=host_name
      - name equals `guest_name` -> speaker="John" (male voice),
        display_speaker=guest_name
      - anything else            -> speaker="Jane", display_speaker is the
        name the LLM returned

    Args:
        system_prompt: Instructions placed at the start of the prompt.
        input_text: Source material the script should be based on.
        tone: One of "Humorous", "Formal", "Casual", "Youthful"; any other
            value falls back to "casual".
        target_length: Free text; the first integer in it is read as
            minutes (3 when absent).
        host_name: Display name of the host.
        guest_name: Display name of the guest.

    Returns:
        A Dialogue holding one DialogueItem per script line.

    Raises:
        ValueError: If the Groq client or request fails, if the answer
            contains no JSON object, or if the JSON cannot be parsed or
            validated.
    """
    print("[LOG] Generating script with tone:", tone, "and length:", target_length)

    # Instead of a fixed mapping, parse numeric minutes from target_length if possible
    words_per_minute = 150
    numeric_minutes = 3
    match = re.search(r"(\d+)", target_length)
    if match:
        numeric_minutes = int(match.group(1))
    min_words = max(50, numeric_minutes * 100)  # rough lower bound
    max_words = numeric_minutes * words_per_minute

    tone_description = {
        "Humorous": "funny and exciting, makes people chuckle",
        "Formal": "business-like, well-structured, professional",
        "Casual": "like a conversation between close friends, relaxed and informal",
        "Youthful": "like how teenagers might chat, energetic and lively"
    }
    chosen_tone = tone_description.get(tone, "casual")

    # JSON-encode the names so quotes or non-ASCII characters in a custom
    # name cannot break the example shown to the model. With the default
    # names this yields exactly "Jane" and "John".
    host_json = json.dumps(host_name, ensure_ascii=False)
    guest_json = json.dumps(guest_name, ensure_ascii=False)

    # Construct prompt
    prompt = (
        f"{system_prompt}\n"
        f"TONE: {chosen_tone}\n"
        f"TARGET LENGTH: {target_length} (~{min_words}-{max_words} words)\n"
        f"INPUT TEXT: {input_text}\n\n"
        "Please provide the output in the following JSON format without any additional text:\n\n"
        "{\n"
        ' "dialogue": [\n'
        ' {\n'
        f' "speaker": {host_json},\n'
        ' "text": "..." \n'
        ' },\n'
        ' {\n'
        f' "speaker": {guest_json},\n'
        ' "text": "..." \n'
        ' }\n'
        " ]\n"
        "}"
    )
    print("[LOG] Sending prompt to Groq:")
    print(prompt)

    try:
        # The client is built inside the try so that a missing API key is
        # reported as ValueError like every other Groq failure.
        groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        response = groq_client.chat.completions.create(
            messages=[{"role": "system", "content": prompt}],
            model="llama-3.3-70b-versatile",
            max_tokens=2048,
            temperature=0.7
        )
    except Exception as e:
        print("[ERROR] Groq API error:", e)
        raise ValueError(f"Error communicating with Groq API: {str(e)}") from e

    # The message content may be None; an empty string then leads to the
    # "No JSON found" error below instead of an AttributeError.
    raw_content = (response.choices[0].message.content or "").strip()

    # Keep only the outermost {...} span, ignoring anything around it.
    start_index = raw_content.find('{')
    end_index = raw_content.rfind('}')
    if start_index == -1 or end_index == -1:
        raise ValueError("Failed to parse dialogue: No JSON found.")
    json_str = raw_content[start_index:end_index+1].strip()

    try:
        data = json.loads(json_str)
        dialogue_list = data.get("dialogue", [])

        host_key = host_name.strip().lower()
        guest_key = guest_name.strip().lower()

        # Post-process to ensure correct TTS speaker + custom display name
        dialogue_items = []
        for d in dialogue_list:
            raw_speaker = str(d.get("speaker") or "Jane")
            speaker_key = raw_speaker.strip().lower()
            if speaker_key == host_key:
                d["speaker"] = "Jane"
                d["display_speaker"] = host_name
            elif speaker_key == guest_key:
                d["speaker"] = "John"
                d["display_speaker"] = guest_name
            else:
                # Unknown name: assume it is the host's voice, but keep the
                # name the LLM produced for display.
                d["speaker"] = "Jane"
                d["display_speaker"] = raw_speaker
            # Convert dict -> DialogueItem
            dialogue_items.append(DialogueItem(**d))
        return Dialogue(dialogue=dialogue_items)
    except json.JSONDecodeError as e:
        print("[ERROR] JSON decoding (format) failed:", e)
        raise ValueError(f"Failed to parse dialogue: {str(e)}") from e
    except Exception as e:
        print("[ERROR] JSON decoding failed:", e)
        raise ValueError(f"Failed to parse dialogue: {str(e)}") from e
def transcribe_youtube_video(video_url: str) -> str:
    """
    Retrieve the English transcript of a YouTube video through the RapidAPI
    'youtube-transcriptor' service.

    Steps:
      1) Pull the 11-character video ID out of `video_url`.
      2) Request the transcript from the RapidAPI endpoint (lang=en).
      3) Read 'transcriptionAsText' from the first entry of the response.
      4) Hand that text back to the caller.

    Raises ValueError when the URL holds no video ID, and also for any
    request, status, or response-format problem.
    """
    print("[LOG] Transcribing YouTube video via RapidAPI:", video_url)

    # Extract video ID
    id_match = re.search(r"(?:v=|\/)([0-9A-Za-z_-]{11})", video_url)
    if id_match is None:
        raise ValueError(f"Invalid YouTube URL: {video_url}, cannot extract video ID.")
    video_id = id_match.group(1)
    print("[LOG] Extracted video ID:", video_id)

    endpoint = "https://youtube-transcriptor.p.rapidapi.com/transcript"
    query = {"video_id": video_id, "lang": "en"}
    request_headers = {
        "x-rapidapi-host": "youtube-transcriptor.p.rapidapi.com",
        "x-rapidapi-key": os.environ.get("RAPIDAPI_KEY"),
    }

    try:
        response = requests.get(endpoint, headers=request_headers, params=query, timeout=30)
        print("[LOG] RapidAPI Response Status Code:", response.status_code)
        print("[LOG] RapidAPI Response Body:", response.text)
        if response.status_code != 200:
            raise ValueError(f"RapidAPI transcription error: {response.status_code}, {response.text}")

        data = response.json()
        # The payload must be a non-empty list; its first entry carries the text.
        if not (isinstance(data, list) and data):
            raise ValueError(f"Unexpected transcript format or empty transcript: {data}")

        transcript = data[0].get('transcriptionAsText', '').strip()
        if not transcript:
            raise ValueError("transcriptionAsText field is missing or empty.")

        print("[LOG] Transcript retrieval successful.")
        print(f"[DEBUG] Transcript Length: {len(transcript)} characters.")
        # Log at most the first 200 characters.
        snippet = transcript[:200] + "..." if len(transcript) > 200 else transcript
        print(f"[DEBUG] Transcript Snippet: {snippet}")
        return transcript
    except Exception as e:
        # Every failure inside the request/parse phase is re-raised as ValueError.
        print("[ERROR] RapidAPI transcription error:", e)
        raise ValueError(f"Error transcribing YouTube video via RapidAPI: {str(e)}")
def generate_audio_mp3(text: str, speaker: str) -> str:
    """
    Synthesize `text` with Deepgram TTS and return the path of a temp MP3.

    The text is first passed through _preprocess_text_for_tts. The audio is
    streamed to a temporary file, volume-normalized with pydub, and exported
    to a second temporary file whose path is returned. The intermediate file
    is always removed, and a partially written output file is removed when
    an error occurs.

    Args:
        text: The line to speak.
        speaker: "John" selects the "aura-zeus-en" voice; any other value
            uses "aura-asteria-en".

    Returns:
        Path of the normalized MP3 file. The caller owns this file and is
        responsible for deleting it.

    Raises:
        ValueError: On any failure (HTTP error, unexpected content type,
            timeout, decoding/encoding problem).
    """

    def remove_quietly(path):
        # Cleanup must never mask the error that triggered it.
        if path and os.path.exists(path):
            try:
                os.remove(path)
            except OSError:
                pass

    raw_mp3_path = None
    final_mp3_path = None
    try:
        print(f"[LOG] Generating audio for speaker: {speaker}")
        # Preprocess text with speaker context
        processed_text = _preprocess_text_for_tts(text, speaker)

        # Deepgram TTS endpoint
        deepgram_api_url = "https://api.deepgram.com/v1/speak"
        # Default female voice; "John" switches to the male voice.
        params = {
            "model": "aura-zeus-en" if speaker == "John" else "aura-asteria-en",
        }
        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "Authorization": f"Token {os.environ.get('DEEPGRAM_API_KEY')}"
        }
        body = {
            "text": processed_text
        }

        # The context manager releases the streamed connection on every path;
        # the timeout bounds the wait (requests has no default timeout).
        with requests.post(deepgram_api_url, params=params, headers=headers,
                           json=body, stream=True, timeout=30) as response:
            if response.status_code != 200:
                raise ValueError(f"Deepgram TTS error: {response.status_code}, {response.text}")
            content_type = response.headers.get('Content-Type', '')
            if 'audio/mpeg' not in content_type:
                raise ValueError("Unexpected Content-Type from Deepgram.")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_file:
                raw_mp3_path = mp3_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        mp3_file.write(chunk)

        # Normalize volume
        audio_seg = AudioSegment.from_file(raw_mp3_path, format="mp3")
        audio_seg = effects.normalize(audio_seg)

        # Reserve the output path and close the handle right away instead of
        # leaving it for the garbage collector.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as final_file:
            final_mp3_path = final_file.name
        audio_seg.export(final_mp3_path, format="mp3")
        return final_mp3_path
    except Exception as e:
        print("[ERROR] Error generating audio:", e)
        remove_quietly(final_mp3_path)
        raise ValueError(f"Error generating audio: {str(e)}") from e
    finally:
        # The un-normalized download is never needed after this call.
        remove_quietly(raw_mp3_path)
def transcribe_youtube_video_OLD_YTDLP(video_url: str) -> str:
    """
    Placeholder for the original ytdlp-based local transcription approach.

    No longer used and kept for reference only: the body does nothing, so a
    call ignores `video_url` and yields None.
    """
    return None
def _preprocess_text_for_tts(text: str, speaker: str) -> str:
"""
Enhances text for natural-sounding TTS by handling abbreviations,
punctuation, and intelligent filler insertion.
Adjustments are made based on the speaker to optimize output quality.
New: We'll handle "SaaS" so that it is read as "S A A S".
"""
# 1) Hyphens -> spaces
text = re.sub(r"-", " ", text)
# 2) Convert decimals (e.g., 3.14 -> 'three point one four')
def convert_decimal(m):
number_str = m.group()
parts = number_str.split('.')
whole_part = _spell_digits(parts[0])
decimal_part = " ".join(_spell_digits(d) for d in parts[1])
return f"{whole_part} point {decimal_part}"
text = re.sub(r"\d+\.\d+", convert_decimal, text)
# 3) Abbreviations (e.g., NASA -> N A S A).
# We'll also handle "SaaS" -> "S A A S" specifically.
def expand_abbreviations(match):
abbrev = match.group()
# Special handling for "SaaS" -> "S A A S"
if abbrev.lower() == "saas":
return "S A A S"
# Check if it's plural with capital letters
if abbrev.endswith('s') and abbrev[:-1].isupper():
singular = abbrev[:-1]
expanded = " ".join(list(singular)) + "s"
specific_plural = {
"MPs": "M Peas",
"TMTs": "T M Tees",
"ARJs": "A R Jays",
}
return specific_plural.get(abbrev, expanded)
else:
return " ".join(list(abbrev))
text = re.sub(r"\b[A-Z]{2,}s?\b", expand_abbreviations, text)
# 5) Intelligent filler insertion after specific keywords (skip for Jane)
if speaker != "Jane":
def insert_thinking_pause(m):
word = m.group(1)
if random.random() < 0.3: # 30% chance
filler = random.choice(['hmm,', 'well,', 'let me see,'])
return f"{word}..., {filler}"
else:
return f"{word}...,"
keywords_pattern = r"\b(important|significant|crucial|point|topic)\b"
text = re.sub(keywords_pattern, insert_thinking_pause, text, flags=re.IGNORECASE)
# 6) Insert dynamic pauses within sentences (for non-Jane speakers)
if speaker != "Jane":
conjunctions_pattern = r"\b(and|but|so|because|however)\b"
text = re.sub(conjunctions_pattern, lambda m: f"{m.group()}...", text, flags=re.IGNORECASE)
# 7) Remove any unintended random fillers (safeguard)
text = re.sub(r"\b(uh|um|ah)\b", "", text, flags=re.IGNORECASE)
# 8) Ensure normal grammar and speaking style
def capitalize_match(match):
return match.group().upper()
text = re.sub(r'(^\s*\w)|([.!?]\s*\w)', capitalize_match, text)
return text.strip()
def _spell_digits(d: str) -> str:
"""
Convert digits '3' -> 'three', etc.
"""
digit_map = {
'0': 'zero',
'1': 'one',
'2': 'two',
'3': 'three',
'4': 'four',
'5': 'five',
'6': 'six',
'7': 'seven',
'8': 'eight',
'9': 'nine'
}
return " ".join(digit_map[ch] for ch in d if ch in digit_map)
def mix_with_bg_music(spoken: AudioSegment, custom_music_path=None) -> AudioSegment:
    """
    Lay `spoken` over background music.

    The music comes from `custom_music_path` when given, otherwise from
    "bg_music.mp3". It is:
      1) lowered by 18 dB so the speech stays clear,
      2) repeated until it covers the speech plus a 2-second lead-in,
      3) trimmed to exactly that length, and
      4) overlaid with the speech starting at the 2-second mark.

    Args:
        spoken: The speech track.
        custom_music_path: Optional path to an MP3 file to use as music.

    Returns:
        The mixed track, or `spoken` unchanged when the music cannot be
        loaded or has zero length.
    """
    music_path = custom_music_path or "bg_music.mp3"
    try:
        bg_music = AudioSegment.from_file(music_path, format="mp3")
    except Exception as e:
        print("[ERROR] Failed to load background music:", e)
        return spoken

    # A zero-length track can never fill the target duration; repeating it
    # would loop forever, so fall back to speech only.
    if len(bg_music) == 0:
        print("[ERROR] Background music is empty; returning speech only.")
        return spoken

    # Reduce background music volume
    bg_music = bg_music - 18.0

    lead_in_ms = 2000
    total_length_ms = len(spoken) + lead_in_ms

    # Repeat the track in one step instead of concatenating in a loop. One
    # extra repetition guarantees the result is at least as long as needed
    # before it is trimmed.
    repeats = total_length_ms // len(bg_music) + 1
    looped_music = (bg_music * repeats)[:total_length_ms]

    return looped_music.overlay(spoken, position=lead_in_ms)