Spaces:
Running
Running
# utils.py | |
import os | |
import re | |
import json | |
import requests | |
import tempfile | |
from bs4 import BeautifulSoup | |
from typing import List, Literal | |
from pydantic import BaseModel | |
from pydub import AudioSegment, effects | |
from transformers import pipeline | |
import yt_dlp | |
import tiktoken | |
from groq import Groq | |
import numpy as np | |
import torch | |
import random | |
class DialogueItem(BaseModel):
    # One line of the generated two-person dialogue.

    # Internal speaker slot; only the two literal values below are accepted.
    speaker: Literal["Jane", "John"]
    # Name to show for this line; falls back to "Jane" when not supplied.
    display_speaker: str = "Jane"
    # The words spoken in this line.
    text: str
class Dialogue(BaseModel):
    # Container for a whole script: the ordered list of dialogue lines.
    dialogue: List[DialogueItem]
# Whisper speech-recognition pipeline, constructed once when the module is
# imported. Not used for YouTube transcription; kept for local audio if needed.
# Device index 0 is requested when CUDA is available, otherwise -1.
asr_pipeline = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    device=-1 if not torch.cuda.is_available() else 0,
)
def truncate_text(text, max_tokens=2048):
    """Return ``text`` limited to at most ``max_tokens`` tokens.

    Tokens are counted with tiktoken's ``cl100k_base`` encoding. Text that is
    already within the limit is returned unchanged; otherwise the first
    ``max_tokens`` tokens are decoded back into a string.
    """
    print("[LOG] Truncating text if needed.")
    encoding = tiktoken.get_encoding("cl100k_base")
    token_ids = encoding.encode(text)
    if len(token_ids) <= max_tokens:
        return text
    print("[LOG] Text too long, truncating.")
    return encoding.decode(token_ids[:max_tokens])
def extract_text_from_url(url, timeout=30):
    """Download a web page and return its visible text.

    Args:
        url: Address of the page to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on an unresponsive
            host.

    Returns:
        The page text with ``<script>`` and ``<style>`` content removed and
        text nodes joined by single spaces, or ``""`` when the request does
        not return HTTP 200 or any exception occurs (best-effort behaviour).
    """
    print("[LOG] Extracting text from URL:", url)
    # Browser-like User-Agent sent with the request.
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/115.0.0.0 Safari/537.36"
        )
    }
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            print(f"[ERROR] Failed to fetch URL: {url} with status code {response.status_code}")
            return ""
        soup = BeautifulSoup(response.text, 'html.parser')
        # Drop non-visible content before extracting text.
        for tag in soup(["script", "style"]):
            tag.decompose()
        text = soup.get_text(separator=' ')
        print("[LOG] Text extraction from URL successful.")
        return text
    except Exception as e:
        print(f"[ERROR] Exception during text extraction from URL: {e}")
        return ""
def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
    """Shift the pitch of ``audio`` by ``semitones``.

    The raw samples are re-labelled with a frame rate scaled by
    ``2 ** (semitones / 12)`` and the result is then converted back to the
    original frame rate.
    """
    print(f"[LOG] Shifting pitch by {semitones} semitones.")
    original_rate = audio.frame_rate
    ratio = 2.0 ** (semitones / 12.0)
    retagged = audio._spawn(
        audio.raw_data,
        overrides={'frame_rate': int(original_rate * ratio)},
    )
    return retagged.set_frame_rate(original_rate)
def is_sufficient(text: str, min_word_count: int = 500) -> bool:
    """Tell whether ``text`` has at least ``min_word_count`` whitespace-separated words.

    The word count is also printed as a debug line.
    """
    words = text.split()
    print(f"[DEBUG] Aggregated word count: {len(words)}")
    return not len(words) < min_word_count
def query_llm_for_additional_info(topic: str, existing_text: str) -> str:
    """Ask the Groq-hosted LLM for extra background on ``topic``.

    Args:
        topic: Subject to expand on.
        existing_text: Information already gathered; embedded in the prompt.

    Returns:
        The model's reply with surrounding whitespace stripped, or ``""``
        when the API call fails or the reply carries no usable text. This
        function never raises for API/response problems so that callers can
        treat it as a best-effort fallback.
    """
    print("[LOG] Querying LLM for additional information.")
    system_prompt = (
        "You are an AI assistant with extensive knowledge up to 2023-10. "
        "Provide additional relevant information on the following topic based on your knowledge base.\n\n"
        f"Topic: {topic}\n\n"
        f"Existing Information: {existing_text}\n\n"
        "Please add more insightful details, facts, and perspectives to enhance the understanding of the topic."
    )
    groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt}],
            model="llama-3.3-70b-versatile",
            max_tokens=1024,
            temperature=0.7
        )
        # Unpack inside the try: an empty `choices` list must also lead to the
        # "" fallback instead of an uncaught IndexError.
        content = response.choices[0].message.content
    except Exception as e:
        print("[ERROR] Groq API error during fallback:", e)
        return ""
    # `content` may be None; treat that as "no additional information".
    additional_info = (content or "").strip()
    print("[DEBUG] Additional information from LLM:")
    print(additional_info)
    return additional_info
def research_topic(topic: str) -> str:
    """Collect background text about ``topic`` from Wikipedia and news feeds.

    Steps:
      1. Add the Wikipedia summary, when one is found.
      2. For each RSS source, look for an item relevant to the topic. If the
         item has a link, the fetched article text is added (when non-empty);
         if it has no link, its title and description are added instead.
      3. If the aggregated text is not sufficient (see ``is_sufficient``),
         ask the LLM for additional information and append it.

    Args:
        topic: Subject to research.

    Returns:
        The aggregated text, or an apology sentence naming the topic when
        nothing at all could be gathered.
    """
    sources = {
        "BBC": "https://feeds.bbci.co.uk/news/rss.xml",
        "CNN": "http://rss.cnn.com/rss/edition.rss",
        "Associated Press": "https://apnews.com/apf-topnews",
        "NDTV": "https://www.ndtv.com/rss/top-stories",
        "Times of India": "https://timesofindia.indiatimes.com/rssfeeds/296589292.cms",
        "The Hindu": "https://www.thehindu.com/news/national/kerala/rssfeed.xml",
        "Economic Times": "https://economictimes.indiatimes.com/rssfeeds/1977021501.cms",
        "Google News - Custom": f"https://news.google.com/rss/search?q={requests.utils.quote(topic)}&hl=en-IN&gl=IN&ceid=IN:en",
    }
    summary_parts = []

    # Wikipedia summary
    wiki_summary = fetch_wikipedia_summary(topic)
    if wiki_summary:
        summary_parts.append(f"From Wikipedia: {wiki_summary}")

    for name, feed_url in sources.items():
        try:
            items = fetch_rss_feed(feed_url)
            if not items:
                continue
            title, desc, link = find_relevant_article(items, topic, min_match=2)
            if title is None:
                # No item matched. Without this guard the headline branch
                # below would append the literal text "None - None".
                continue
            if link:
                article_text = fetch_article_text(link)
                if article_text:
                    summary_parts.append(f"From {name}: {article_text}")
            else:
                # Matched item without a link: fall back to its headline.
                summary_parts.append(f"From {name}: {title} - {desc}")
        except Exception as e:
            # One failing source must not abort the whole research run.
            print(f"[ERROR] Error fetching from {name} RSS feed:", e)
            continue

    aggregated_info = " ".join(summary_parts)
    print("[DEBUG] Aggregated info from primary sources:")
    print(aggregated_info)

    if not is_sufficient(aggregated_info):
        print("[LOG] Insufficient info from primary sources. Fallback to LLM.")
        additional_info = query_llm_for_additional_info(topic, aggregated_info)
        if additional_info:
            aggregated_info += " " + additional_info
        else:
            print("[ERROR] Failed to retrieve additional info from LLM.")

    if not aggregated_info:
        return f"Sorry, I couldn't find recent information on '{topic}'."
    return aggregated_info
def fetch_wikipedia_summary(topic: str, timeout: float = 30) -> str:
    """Return the English Wikipedia summary extract for ``topic``.

    The topic is first resolved to a page title through the ``opensearch``
    API (first hit only); the title is then looked up on the REST
    ``page/summary`` endpoint.

    Args:
        topic: Free-text search term.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        The ``extract`` field of the summary, or ``""`` when nothing is
        found, a request does not return HTTP 200, or any exception occurs.
    """
    print("[LOG] Fetching Wikipedia summary for:", topic)
    try:
        search_url = (
            f"https://en.wikipedia.org/w/api.php?action=opensearch&search={requests.utils.quote(topic)}"
            "&limit=1&namespace=0&format=json"
        )
        resp = requests.get(search_url, timeout=timeout)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch Wikipedia search results for {topic}")
            return ""
        data = resp.json()
        # data[1] is read as the list of matching titles; use the first one.
        if len(data) > 1 and data[1]:
            title = data[1][0]
            # The title is a single path segment: encode "/" too (safe=""),
            # otherwise a title such as "AC/DC" is split into two segments.
            encoded_title = requests.utils.quote(title.replace(" ", "_"), safe="")
            summary_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{encoded_title}"
            s_resp = requests.get(summary_url, timeout=timeout)
            if s_resp.status_code == 200:
                s_data = s_resp.json()
                if "extract" in s_data:
                    print("[LOG] Wikipedia summary fetched successfully.")
                    return s_data["extract"]
        return ""
    except Exception as e:
        print(f"[ERROR] Exception during Wikipedia summary fetch: {e}")
        return ""
def fetch_rss_feed(feed_url: str, timeout: float = 30) -> list:
    """Download an RSS feed and return its ``<item>`` elements.

    Args:
        feed_url: Address of the feed.
        timeout: Seconds to wait for the server before giving up, so one
            unresponsive feed cannot stall the caller indefinitely.

    Returns:
        The ``item`` tags found by BeautifulSoup's XML parser, or ``[]`` when
        the request does not return HTTP 200 or any exception occurs.
    """
    print("[LOG] Fetching RSS feed:", feed_url)
    try:
        resp = requests.get(feed_url, timeout=timeout)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch RSS feed: {feed_url}")
            return []
        soup = BeautifulSoup(resp.content, "xml")
        return soup.find_all("item")
    except Exception as e:
        print(f"[ERROR] Exception fetching RSS feed {feed_url}: {e}")
        return []
def find_relevant_article(items, topic: str, min_match=2) -> tuple:
    """Return the first feed item whose title/description mentions the topic.

    The topic is split into lowercase word keywords (duplicates removed). An
    item is relevant when at least ``min_match`` distinct keywords occur as
    substrings of its lowercased title + description. When the topic has
    fewer distinct keywords than ``min_match`` (e.g. a one-word topic), all
    of them must occur; otherwise such topics could never match anything.

    Args:
        items: Iterable of objects exposing ``find(name)`` whose result has
            ``get_text()`` (e.g. BeautifulSoup ``item`` tags).
        topic: Free-text topic.
        min_match: Number of distinct keywords that must be present.

    Returns:
        ``(title, description, link)`` of the first relevant item (missing
        fields become ``""``), or ``(None, None, None)`` when no item is
        relevant or the topic contains no keywords.
    """
    print("[LOG] Finding relevant articles...")

    def tag_text(item, name):
        # Single lookup per field; a missing tag yields an empty string.
        tag = item.find(name)
        return tag.get_text().strip() if tag else ""

    # dict.fromkeys removes duplicates while keeping the original order, so a
    # repeated topic word is not counted twice.
    keywords = list(dict.fromkeys(re.findall(r'\w+', topic.lower())))
    if not keywords:
        return None, None, None
    required = min(min_match, len(keywords))

    for item in items:
        title = tag_text(item, "title")
        description = tag_text(item, "description")
        text = (title + " " + description).lower()
        matches = sum(1 for kw in keywords if kw in text)
        if matches >= required:
            link = tag_text(item, "link")
            print(f"[LOG] Relevant article found: {title}")
            return title, description, link
    return None, None, None
def fetch_article_text(link: str, timeout: float = 30) -> str:
    """Fetch an article page and return the text of its first five paragraphs.

    Args:
        link: Article URL; an empty value short-circuits to ``""`` without
            any network access.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The text of the first five ``<p>`` elements joined by spaces and
        stripped, or ``""`` when there is no link, the request does not
        return HTTP 200, or any exception occurs.
    """
    print("[LOG] Fetching article text from:", link)
    if not link:
        print("[LOG] No link provided for article text.")
        return ""
    try:
        resp = requests.get(link, timeout=timeout)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch article from {link}")
            return ""
        soup = BeautifulSoup(resp.text, 'html.parser')
        paragraphs = soup.find_all("p")
        text = " ".join(p.get_text() for p in paragraphs[:5])  # first 5 paragraphs
        print("[LOG] Article text fetched successfully.")
        return text.strip()
    except Exception as e:
        print(f"[ERROR] Error fetching article text: {e}")
        return ""
def generate_script(
    system_prompt: str,
    input_text: str,
    tone: str,
    target_length: str,
    host_name: str = "Jane",
    guest_name: str = "John",
    sponsor_style: str = "Separate Break"
):
    """Ask the Groq LLM for a two-speaker script and parse it into a ``Dialogue``.

    If sponsor content is empty, we won't have sponsor instructions appended
    in app.py's prompt. So the LLM should not generate sponsor segments.

    Args:
        system_prompt: Base instructions placed at the top of the prompt.
        input_text: Source material the script is based on.
        tone: Key into the tone table ("Humorous", "Formal", "Casual",
            "Youthful"); unknown values fall back to "casual".
        target_length: Text containing the number of minutes (first run of
            digits is used; 3 when there is none).
        host_name: Speaker name mapped to the "Jane" slot.
        guest_name: Speaker name mapped to the "John" slot.
        sponsor_style: "Separate Break" for a distinct ad break; any other
            value asks for the sponsor content to be blended in.

    Returns:
        A ``Dialogue`` whose items carry the normalized ``speaker`` slot and
        the name to display in ``display_speaker``.

    Raises:
        ValueError: If the API call fails, the reply contains no JSON object,
            or the JSON cannot be decoded/validated.
    """
    print("[LOG] Generating script with tone:", tone, "and length:", target_length)
    groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

    words_per_minute = 150
    numeric_minutes = 3
    match = re.search(r"(\d+)", target_length)
    if match:
        numeric_minutes = int(match.group(1))
    min_words = max(50, numeric_minutes * 100)
    # Never advertise an inverted range (e.g. "~50-0 words" for 0 minutes).
    max_words = max(min_words, numeric_minutes * words_per_minute)

    tone_map = {
        "Humorous": "funny and exciting, makes people chuckle",
        "Formal": "business-like, well-structured, professional",
        "Casual": "like a conversation between close friends, relaxed and informal",
        "Youthful": "like how teenagers might chat, energetic and lively"
    }
    chosen_tone = tone_map.get(tone, "casual")

    if sponsor_style == "Separate Break":
        sponsor_instructions = (
            "If sponsor content is provided, include it in a separate ad break (~30 seconds). "
            "Use phrasing like 'Now a word from our sponsor...' and end with 'Back to the show' or similar."
        )
    else:
        sponsor_instructions = (
            "If sponsor content is provided, blend it naturally (~30 seconds) into the conversation. "
            "Avoid abrupt transitions."
        )

    prompt = (
        f"{system_prompt}\n"
        f"TONE: {chosen_tone}\n"
        f"TARGET LENGTH: {target_length} (~{min_words}-{max_words} words)\n"
        f"INPUT TEXT: {input_text}\n\n"
        f"# Sponsor Style Instruction:\n{sponsor_instructions}\n\n"
        "Please provide the output in the following JSON format without any additional text:\n\n"
        "{\n"
        ' "dialogue": [\n'
        ' {\n'
        ' "speaker": "Jane",\n'
        ' "text": "..." \n'
        ' },\n'
        ' {\n'
        ' "speaker": "John",\n'
        ' "text": "..." \n'
        ' }\n'
        " ]\n"
        "}"
    )
    print("[LOG] Sending prompt to Groq:")
    print(prompt)

    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "system", "content": prompt}],
            model="llama-3.3-70b-versatile",
            max_tokens=2048,
            temperature=0.7
        )
    except Exception as e:
        print("[ERROR] Groq API error:", e)
        raise ValueError(f"Error communicating with Groq API: {str(e)}") from e

    # The message content may be None; treat it like an empty reply so the
    # "No JSON found" error below is raised instead of an AttributeError.
    raw_content = (response.choices[0].message.content or "").strip()

    # Keep only the outermost {...} span; models often wrap JSON in prose.
    start_index = raw_content.find('{')
    end_index = raw_content.rfind('}')
    if start_index == -1 or end_index < start_index:
        raise ValueError("Failed to parse dialogue: No JSON found.")
    json_str = raw_content[start_index:end_index + 1].strip()

    try:
        data = json.loads(json_str)
        dialogue_items = []
        for entry in data.get("dialogue", []):
            # A missing, null or empty speaker is treated as "Jane".
            raw_speaker = str(entry.get("speaker") or "Jane")
            if raw_speaker.lower() == host_name.lower():
                entry["speaker"] = "Jane"
                entry["display_speaker"] = host_name
            elif raw_speaker.lower() == guest_name.lower():
                entry["speaker"] = "John"
                entry["display_speaker"] = guest_name
            else:
                # Unknown speaker: use the "Jane" slot but keep the name shown.
                entry["speaker"] = "Jane"
                entry["display_speaker"] = raw_speaker
            dialogue_items.append(DialogueItem(**entry))
        return Dialogue(dialogue=dialogue_items)
    except json.JSONDecodeError as e:
        print("[ERROR] JSON decoding (format) failed:", e)
        raise ValueError(f"Failed to parse dialogue: {str(e)}") from e
    except Exception as e:
        print("[ERROR] JSON decoding failed:", e)
        raise ValueError(f"Failed to parse dialogue: {str(e)}") from e
def transcribe_youtube_video(video_url: str) -> str:
    """Fetch the English transcript of a YouTube video through RapidAPI.

    The 11-character video ID is extracted from ``video_url`` and sent to the
    ``youtube-transcriptor`` endpoint using the ``RAPIDAPI_KEY`` environment
    variable. The ``transcriptionAsText`` field of the first list entry in
    the JSON reply is returned.

    Raises:
        ValueError: If no video ID can be extracted, or for any failure while
            requesting or reading the transcript (inner errors are re-wrapped
            with a "via RapidAPI" prefix).
    """
    print("[LOG] Transcribing YouTube video via RapidAPI:", video_url)
    id_match = re.search(r"(?:v=|\/)([0-9A-Za-z_-]{11})", video_url)
    if id_match is None:
        raise ValueError(f"Invalid YouTube URL: {video_url}, cannot extract video ID.")
    video_id = id_match.group(1)
    print("[LOG] Extracted video ID:", video_id)

    endpoint = "https://youtube-transcriptor.p.rapidapi.com/transcript"
    query = {"video_id": video_id, "lang": "en"}
    request_headers = {
        "x-rapidapi-host": "youtube-transcriptor.p.rapidapi.com",
        "x-rapidapi-key": os.environ.get("RAPIDAPI_KEY"),
    }

    try:
        response = requests.get(endpoint, headers=request_headers, params=query, timeout=30)
        print("[LOG] RapidAPI Response Status Code:", response.status_code)
        print("[LOG] RapidAPI Response Body:", response.text)
        if response.status_code != 200:
            raise ValueError(f"RapidAPI transcription error: {response.status_code}, {response.text}")

        data = response.json()
        # The reply must be a non-empty list.
        if not (isinstance(data, list) and data):
            raise ValueError(f"Unexpected transcript format or empty transcript: {data}")

        transcript = data[0].get('transcriptionAsText', '').strip()
        if not transcript:
            raise ValueError("transcriptionAsText field is missing or empty.")

        print("[LOG] Transcript retrieval successful.")
        print(f"[DEBUG] Transcript Length: {len(transcript)} characters.")
        # Log at most the first 200 characters.
        if len(transcript) > 200:
            snippet = transcript[:200] + "..."
        else:
            snippet = transcript
        print(f"[DEBUG] Transcript Snippet: {snippet}")
        return transcript
    except Exception as e:
        print("[ERROR] RapidAPI transcription error:", e)
        raise ValueError(f"Error transcribing YouTube video via RapidAPI: {str(e)}")
def generate_audio_mp3(text: str, speaker: str, timeout: float = 60) -> str:
    """Synthesize ``text`` with Deepgram TTS and return a normalized MP3 path.

    The text is first run through ``_preprocess_text_for_tts``. The streamed
    MP3 reply is written to a temporary file, volume-normalized with pydub,
    and exported to a second temporary file whose path is returned. The
    caller owns the returned file.

    Args:
        text: Text to speak.
        speaker: "John" selects the ``aura-helios-en`` model; anything else
            uses ``aura-asteria-en``.
        timeout: Seconds to wait on the Deepgram connection/reads before
            giving up.

    Returns:
        Path of the normalized MP3 file.

    Raises:
        ValueError: On any failure (HTTP error, unexpected content type,
            decoding/export problems).
    """
    raw_mp3_path = None
    final_mp3_path = None
    try:
        print(f"[LOG] Generating audio for speaker: {speaker}")
        processed_text = _preprocess_text_for_tts(text, speaker)

        deepgram_api_url = "https://api.deepgram.com/v1/speak"
        params = {
            "model": "aura-helios-en" if speaker == "John" else "aura-asteria-en",
        }
        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "Authorization": f"Token {os.environ.get('DEEPGRAM_API_KEY')}"
        }
        body = {
            "text": processed_text
        }

        # `with` releases the streamed connection even when we bail out early.
        with requests.post(
            deepgram_api_url,
            params=params,
            headers=headers,
            json=body,
            stream=True,
            timeout=timeout,
        ) as response:
            if response.status_code != 200:
                raise ValueError(f"Deepgram TTS error: {response.status_code}, {response.text}")
            content_type = response.headers.get('Content-Type', '')
            if 'audio/mpeg' not in content_type:
                raise ValueError("Unexpected Content-Type from Deepgram.")

            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_file:
                # Record the path first so `finally` can clean up on failure.
                raw_mp3_path = mp3_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        mp3_file.write(chunk)

        audio_seg = AudioSegment.from_file(raw_mp3_path, format="mp3")
        audio_seg = effects.normalize(audio_seg)

        # Reserve the output path and close the handle right away instead of
        # leaving an open file object behind.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as final_file:
            final_mp3_path = final_file.name
        audio_seg.export(final_mp3_path, format="mp3")
        return final_mp3_path
    except Exception as e:
        # Do not leave a half-written output file behind.
        if final_mp3_path and os.path.exists(final_mp3_path):
            os.remove(final_mp3_path)
        print("[ERROR] Error generating audio:", e)
        raise ValueError(f"Error generating audio: {str(e)}") from e
    finally:
        # The raw download is only an intermediate; remove it in every case.
        if raw_mp3_path and os.path.exists(raw_mp3_path):
            os.remove(raw_mp3_path)
def transcribe_youtube_video_OLD_YTDLP(video_url: str) -> str:
    """Empty placeholder: performs no work and returns ``None``.

    NOTE(review): judging by the name this was the earlier yt-dlp based
    transcription path; confirm before removing or reviving it.
    """
    return None
def _preprocess_text_for_tts(text: str, speaker: str) -> str: | |
""" | |
1) "SaaS" => "sass" | |
2) Insert periods for uppercase abbreviations -> remove for TTS (N.I.A. => N I A) | |
3) Convert decimals (3.14 => 'three point one four') | |
4) Convert integers (10 => 'ten', 4000 => 'four thousand') | |
5) Expand leftover all-caps | |
6) Emotive placeholders for 'ha', 'haha', 'sigh', 'groan', etc. | |
7) If speaker == "John", insert short breath "..." after punctuation (not random mid-word) | |
8) Remove random fillers (uh, um) | |
9) Capitalize sentence starts | |
""" | |
# 1) "SaaS" => "sass" | |
text = re.sub(r"\b(?i)SaaS\b", "sass", text) | |
# 2) Insert periods for uppercase abbreviations => remove them | |
def insert_periods_for_abbrev(m): | |
abbr = m.group(0) | |
parted = ".".join(list(abbr)) + "." | |
return parted | |
text = re.sub(r"\b([A-Z0-9]{2,})\b", insert_periods_for_abbrev, text) | |
text = re.sub(r"\.\.", ".", text) | |
def remove_periods_for_tts(m): | |
# "N.I.A." => "N I A" | |
chunk = m.group(0) | |
return chunk.replace(".", " ").strip() | |
text = re.sub(r"[A-Z0-9]\.[A-Z0-9](?:\.[A-Z0-9])*\.", remove_periods_for_tts, text) | |
# 3) Hyphens -> spaces | |
text = re.sub(r"-", " ", text) | |
# 4) Convert decimals | |
def convert_decimal(m): | |
number_str = m.group() | |
parts = number_str.split('.') | |
whole_part = _spell_digits(parts[0]) | |
decimal_part = " ".join(_spell_digits(d) for d in parts[1]) | |
return f"{whole_part} point {decimal_part}" | |
text = re.sub(r"\b\d+\.\d+\b", convert_decimal, text) | |
# Convert pure integers => words | |
def convert_int_to_words(m): | |
num_str = m.group() | |
return number_to_words(int(num_str)) | |
text = re.sub(r"\b\d+\b", convert_int_to_words, text) | |
# 5) Expand leftover all-caps => "NASA" => "N A S A" | |
def expand_abbreviations(m): | |
abbrev = m.group() | |
if abbrev.endswith('s') and abbrev[:-1].isupper(): | |
singular = abbrev[:-1] | |
expanded = " ".join(list(singular)) + "s" | |
special_plurals = { | |
"MPs": "M Peas", | |
"TMTs": "T M Tees", | |
"ARJs": "A R Jays", | |
} | |
return special_plurals.get(abbrev, expanded) | |
else: | |
return " ".join(list(abbrev)) | |
text = re.sub(r"\b[A-Z]{2,}s?\b", expand_abbreviations, text) | |
# 6) Emotive placeholders | |
text = re.sub(r"\b(ha(ha)?|heh|lol)\b", "(* laughs *)", text, flags=re.IGNORECASE) | |
text = re.sub(r"\bsigh\b", "(* sighs *)", text, flags=re.IGNORECASE) | |
text = re.sub(r"\b(groan|moan)\b", "(* groans *)", text, flags=re.IGNORECASE) | |
# 7) If speaker == "John", place short "..." after punctuation only | |
if speaker == "John": | |
# Insert a short "..." after . , ! ? ; : | |
text = re.sub(r"([.,!?;:])(\s|$)", r"\1...\2", text) | |
# 8) Remove random fillers | |
text = re.sub(r"\b(uh|um|ah)\b", "", text, flags=re.IGNORECASE) | |
# 9) Capitalize sentence starts | |
def capitalize_match(m): | |
return m.group().upper() | |
text = re.sub(r'(^\s*\w)|([.!?]\s*\w)', capitalize_match, text) | |
return text.strip() | |
def _spell_digits(d: str) -> str: | |
""" | |
Convert individual digits '3' -> 'three'. | |
""" | |
digit_map = { | |
'0': 'zero', | |
'1': 'one', | |
'2': 'two', | |
'3': 'three', | |
'4': 'four', | |
'5': 'five', | |
'6': 'six', | |
'7': 'seven', | |
'8': 'eight', | |
'9': 'nine' | |
} | |
return " ".join(digit_map[ch] for ch in d if ch in digit_map) | |
def number_to_words(n: int) -> str:
    """
    Convert an integer to English words, e.g. 10 => 'ten',
    4000 => 'four thousand', 115 => 'one hundred and fifteen'.

    Groups of three digits are named with thousand / million / billion /
    trillion. "and" is inserted only between "hundred" and a non-zero rest;
    no hyphens are used ('twenty one'). Negative numbers are prefixed with
    'minus'. Values of 10**15 and above are spelled digit by digit.
    """
    if n == 0:
        return "zero"
    if n < 0:
        return "minus " + number_to_words(-n)

    ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    teens = ["ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
             "sixteen", "seventeen", "eighteen", "nineteen"]
    tens_words = ["", "", "twenty", "thirty", "forty", "fifty", "sixty",
                  "seventy", "eighty", "ninety"]
    # Name of each successive group of three digits, least significant first.
    scales = ["", "thousand", "million", "billion", "trillion"]

    def three_digits(x):
        # Words for 1..999; empty entries are dropped when joining.
        words = []
        hundreds, rem = divmod(x, 100)
        if hundreds:
            words += [ones[hundreds], "hundred"]
            if rem:
                words.append("and")
        if rem < 10:
            words.append(ones[rem])
        elif rem < 20:
            words.append(teens[rem - 10])
        else:
            tens, unit = divmod(rem, 10)
            words += [tens_words[tens], ones[unit]]
        return " ".join(w for w in words if w)

    if n >= 1000 ** len(scales):
        # Beyond the named scales: read the digits one by one instead of
        # failing on a missing scale word.
        digit_names = ["zero"] + ones[1:]
        return " ".join(digit_names[int(ch)] for ch in str(n))

    groups = []
    remaining = n
    for scale in scales:
        if remaining == 0:
            break
        remaining, chunk = divmod(remaining, 1000)
        if chunk:
            groups.append(three_digits(chunk) + (" " + scale if scale else ""))
    # Groups were collected least-significant first.
    return " ".join(reversed(groups))
def mix_with_bg_music(spoken: AudioSegment, custom_music_path=None) -> AudioSegment:
    """
    Mixes 'spoken' with bg_music.mp3 or custom music:
    - 2s music lead-in before the speech starts
    - Music is looped if it is shorter than speech + lead-in
    - Music volume is lowered by 18 dB

    Args:
        spoken: The speech track.
        custom_music_path: Optional path of an MP3 to use instead of
            "bg_music.mp3".

    Returns:
        The mixed audio, or ``spoken`` unchanged when the music cannot be
        loaded or has zero length.
    """
    music_path = custom_music_path if custom_music_path else "bg_music.mp3"
    try:
        bg_music = AudioSegment.from_file(music_path, format="mp3")
    except Exception as e:
        print("[ERROR] Failed to load background music:", e)
        return spoken

    if len(bg_music) == 0:
        # Looping a zero-length track can never reach the target length.
        print("[ERROR] Background music is empty; skipping music mix.")
        return spoken

    bg_music = bg_music - 18.0
    total_length_ms = len(spoken) + 2000

    # Repeat the track enough times in one step (one spare copy covers
    # millisecond rounding), then trim to the exact length.
    repeats = total_length_ms // len(bg_music) + 1
    looped_music = (bg_music * repeats)[:total_length_ms]

    # Speech starts after the 2 second lead-in.
    return looped_music.overlay(spoken, position=2000)
def call_groq_api_for_qa(system_prompt: str) -> str:
    """
    Minimal function for short Q&A calls. Must return JSON:
    { "speaker": "John", "text": "Short answer" }

    Args:
        system_prompt: Full prompt sent as the single system message.

    Returns:
        The model's reply with surrounding whitespace stripped. When the API
        call fails or the reply carries no text, a JSON string with an
        apology from "John" is returned instead, so the caller always
        receives a non-empty string.
    """
    fallback = json.dumps(
        {"speaker": "John", "text": "I'm sorry, I'm having trouble answering right now."}
    )
    groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt}],
            model="llama-3.3-70b-versatile",
            max_tokens=512,
            temperature=0.7
        )
        # Unpack inside the try so an empty `choices` list also falls back.
        content = response.choices[0].message.content
    except Exception as e:
        print("[ERROR] Groq API error:", e)
        return fallback

    raw_content = (content or "").strip()
    if not raw_content:
        # None/blank content would break the "must return JSON" contract.
        print("[ERROR] Groq API returned an empty response.")
        return fallback
    return raw_content