Spaces:
Running
Running
import os | |
import re | |
import json | |
import requests | |
import tempfile | |
from bs4 import BeautifulSoup | |
from typing import List, Literal, Optional | |
from pydantic import BaseModel | |
from pydub import AudioSegment, effects | |
from transformers import pipeline | |
import tiktoken | |
from groq import Groq # Retained for LLM interaction | |
import numpy as np | |
import torch | |
import random | |
# --- CORRECT IMPORTS --- | |
# No more sys.path modification! | |
from report_structure import generate_report # For report structuring | |
from tavily import TavilyClient | |
class DialogueItem(BaseModel):
    """One line of a generated two-person script.

    ``speaker`` is the voice role and ``display_speaker`` the name attached to the line.
    """

    # Voice role. generate_script only ever stores "Jane" or "John" here.
    # NOTE(review): generate_audio_mp3 picks its "John" voice when speaker == "John";
    # presumably it is called with this field — confirm at the call site.
    speaker: Literal["Jane", "John"]
    # Name attached to the line. generate_script sets it to the host/guest name or the
    # raw label the LLM used. NOTE(review): the "Jane" default does not follow
    # `speaker`, so a "John" item built without this field would show "Jane".
    display_speaker: str = "Jane"
    # The utterance itself.
    text: str
class Dialogue(BaseModel):
    """A complete generated script, as returned by generate_script."""

    # Dialogue lines in the order the LLM produced them.
    dialogue: List[DialogueItem]
# Module-level speech-recognition pipeline (Whisper tiny, English-only), built once at import.
# device=0 selects the first CUDA GPU when torch reports one; -1 keeps it on the CPU.
asr_pipeline = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    device=(0 if torch.cuda.is_available() else -1),
)
def truncate_text(text, max_tokens=2048):
    """Limit *text* to at most *max_tokens* tokens of the ``cl100k_base`` encoding.

    Text that already fits is returned unchanged. Longer text is cut to its first
    *max_tokens* tokens and decoded back to a string, so the cut may fall mid-word.

    Args:
        text: Arbitrary (possibly untrusted) text, e.g. scraped pages or transcripts.
        max_tokens: Token budget; negative values are treated as 0.

    Returns:
        The original or truncated string.
    """
    print("[LOG] Truncating text if needed.")
    tokenizer = tiktoken.get_encoding("cl100k_base")
    # By default tiktoken raises ValueError when the text contains special-token
    # strings such as "<|endoftext|>"; disallowed_special=() encodes them as plain text.
    tokens = tokenizer.encode(text, disallowed_special=())
    limit = max(max_tokens, 0)
    if len(tokens) <= limit:
        return text
    print("[LOG] Text too long, truncating.")
    return tokenizer.decode(tokens[:limit])
def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
    """Shift the pitch of *audio* by *semitones* (positive = higher).

    The raw samples are re-labelled with a frame rate scaled by 2**(semitones/12)
    and then resampled back to the original rate. Duration therefore changes
    together with pitch.
    """
    print(f"[LOG] Shifting pitch by {semitones} semitones.")
    source_rate = audio.frame_rate
    scaled_rate = int(source_rate * (2.0 ** (semitones / 12.0)))
    relabelled = audio._spawn(audio.raw_data, overrides={'frame_rate': scaled_rate})
    return relabelled.set_frame_rate(source_rate)
# --- Functions no longer needed --- | |
# def is_sufficient(...) | |
# def query_llm_for_additional_info(...) | |
# def research_topic(...) | |
# def fetch_wikipedia_summary(...) | |
# def fetch_rss_feed(...) | |
# def find_relevant_article(...) | |
# def fetch_article_text(...) | |
def generate_script(
    system_prompt: str,
    input_text: str,
    tone: str,
    target_length: str,
    host_name: str = "Jane",
    guest_name: str = "John",
    sponsor_style: str = "Separate Break",
    sponsor_provided=None,
    request_timeout: float = 300.0,
):
    """Ask Deepseek R1 (through OpenRouter) to write a two-speaker script.

    The UI language is read from ``st.session_state["language_selection"]``:
    - For "English (Indian)", "Hinglish" and "Hindi", default or empty host/guest
      names are replaced with "Isha"/"Aarav".
    - For "Hinglish" and "Hindi", a language instruction is appended to the prompt.

    Args:
        system_prompt: Text placed at the top of the prompt.
        input_text: Source material for the dialogue.
        tone: "Humorous", "Formal", "Casual" or "Youthful"; any other value
            falls back to the description "casual".
        target_length: Free text; the first integer in it is the length in
            minutes (3 if none is found).
        host_name: Name attached to lines spoken with the "Jane" voice role.
        guest_name: Name attached to lines spoken with the "John" voice role.
        sponsor_style: "Separate Break" asks for a distinct ad break; any other
            value asks for sponsor content blended into the conversation.
        sponsor_provided: When truthy, sponsor-placement instructions are added.
        request_timeout: Seconds to wait for the OpenRouter HTTP response.

    Returns:
        Dialogue: the parsed script with every line mapped to "Jane" or "John".

    Raises:
        ValueError: if the API call fails or the reply cannot be parsed.
    """
    print("[LOG] Generating script with tone:", tone, "and length:", target_length)
    import streamlit as st  # Import streamlit here, where it's used

    language = st.session_state.get("language_selection")
    if language in ("English (Indian)", "Hinglish", "Hindi"):
        if host_name == "Jane" or not host_name:
            host_name = "Isha"
        if guest_name == "John" or not guest_name:
            guest_name = "Aarav"
    # Never let a None/empty name reach the case-insensitive speaker matching.
    host_name = host_name or "Jane"
    guest_name = guest_name or "John"

    words_per_minute = 150
    numeric_minutes = 3
    match = re.search(r"(\d+)", str(target_length))
    if match:
        numeric_minutes = int(match.group(1))
    min_words = max(50, numeric_minutes * 100)
    max_words = numeric_minutes * words_per_minute

    tone_map = {
        "Humorous": "funny and exciting, makes people chuckle",
        "Formal": "business-like, well-structured, professional",
        "Casual": "like a conversation between close friends, relaxed and informal",
        "Youthful": "like how teenagers might chat, energetic and lively"
    }
    chosen_tone = tone_map.get(tone, "casual")

    if not sponsor_provided:
        sponsor_instructions = ""
    elif sponsor_style == "Separate Break":
        sponsor_instructions = (
            "If sponsor content is provided, include it in a separate ad break (~30 seconds). "
            "Use phrasing like 'Now a word from our sponsor...' and end with 'Back to the show' or similar."
        )
    else:
        sponsor_instructions = (
            "If sponsor content is provided, blend it naturally (~30 seconds) into the conversation. "
            "Avoid abrupt transitions."
        )

    prompt = (
        f"{system_prompt}\n"
        f"TONE: {chosen_tone}\n"
        f"TARGET LENGTH: {target_length} (~{min_words}-{max_words} words)\n"
        f"INPUT TEXT: {input_text}\n\n"
        f"# Sponsor Style Instruction:\n{sponsor_instructions}\n\n"
        "Please provide the output in the following JSON format without any additional text:\n\n"
        "{\n"
        ' "dialogue": [\n'
        ' {\n'
        ' "speaker": "Jane",\n'
        ' "text": "..." \n'
        ' },\n'
        ' {\n'
        ' "speaker": "John",\n'
        ' "text": "..." \n'
        ' }\n'
        " ]\n"
        "}"
    )
    # Add language-specific instructions
    if language == "Hinglish":
        prompt += "\n\nPlease generate the script in Romanized Hindi.\n"
    elif language == "Hindi":
        prompt += "\n\nPlease generate the script exclusively in Hindi, using only Hindi vocabulary and grammar without any English words or phrases.\n"

    # Log the final prompt, including any language instruction appended above.
    print("[LOG] Sending prompt to Deepseek R1 via OpenRouter:")
    print(prompt)

    try:
        headers = {
            "Authorization": f"Bearer {os.environ.get('DEEPSEEK_API_KEY')}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": "deepseek/deepseek-r1",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 2048,
            "temperature": 0.7
        }
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=request_timeout,  # without a timeout a stalled connection hangs forever
        )
        response.raise_for_status()
        raw_content = response.json()["choices"][0]["message"]["content"].strip()
    except Exception as e:
        print("[ERROR] Deepseek API error:", e)
        raise ValueError(f"Error communicating with Deepseek API: {str(e)}") from e

    return _parse_dialogue_response(raw_content, host_name, guest_name)


def _parse_dialogue_response(raw_content: str, host_name: str, guest_name: str) -> Dialogue:
    """Extract the JSON dialogue from an LLM reply and map each speaker onto a voice role.

    Mapping, first match wins (case-insensitive):
    - host_name -> "Jane", shown as host_name
    - guest_name -> "John", shown as guest_name
    - the prompt template's role labels: "john" -> "John"/guest_name and
      "jane" -> "Jane"/host_name
    - any other label -> "Jane", shown as the raw label

    Raises:
        ValueError: if no JSON object is found or it does not fit the Dialogue model.
    """
    # Reasoning models may emit a <think>...</think> section whose braces would
    # otherwise be taken as the start of the JSON object.
    cleaned = re.sub(r"<think>.*?</think>", "", raw_content, flags=re.DOTALL)
    start_index = cleaned.find('{')
    end_index = cleaned.rfind('}')
    if start_index == -1 or end_index == -1:
        raise ValueError("Failed to parse dialogue: No JSON found.")
    json_str = cleaned[start_index:end_index + 1].strip()
    host_key = host_name.strip().lower()
    guest_key = guest_name.strip().lower()
    try:
        parsed = json.loads(json_str)
        items = []
        for entry in parsed.get("dialogue", []):
            raw_speaker = str(entry.get("speaker") or "Jane")
            key = raw_speaker.strip().lower()
            if key == host_key:
                entry["speaker"], entry["display_speaker"] = "Jane", host_name
            elif key == guest_key:
                entry["speaker"], entry["display_speaker"] = "John", guest_name
            elif key == "john":
                # The prompt template labels the guest "John"; keep the guest voice.
                entry["speaker"], entry["display_speaker"] = "John", guest_name
            elif key == "jane":
                entry["speaker"], entry["display_speaker"] = "Jane", host_name
            else:
                entry["speaker"], entry["display_speaker"] = "Jane", raw_speaker
            items.append(DialogueItem(**entry))
        return Dialogue(dialogue=items)
    except json.JSONDecodeError as e:
        print("[ERROR] JSON decoding (format) failed:", e)
        raise ValueError(f"Failed to parse dialogue: {str(e)}") from e
    except Exception as e:
        print("[ERROR] JSON decoding failed:", e)
        raise ValueError(f"Failed to parse dialogue: {str(e)}") from e
def transcribe_youtube_video(video_url: str) -> str:
    """Fetch a YouTube video's English transcript via the RapidAPI "youtube-transcriptor" service.

    Args:
        video_url: URL containing the 11-character video id after ``v=`` or a ``/``.

    Returns:
        The stripped ``transcriptionAsText`` of the first transcript item.

    Raises:
        ValueError: if no video id can be extracted, the request fails, or the
            response carries no transcript text.
    """
    print("[LOG] Transcribing YouTube video via RapidAPI:", video_url)
    # str() so a None/non-string argument yields the ValueError below, not a TypeError.
    video_id_match = re.search(r"(?:v=|\/)([0-9A-Za-z_-]{11})", str(video_url or ""))
    if not video_id_match:
        raise ValueError(f"Invalid YouTube URL: {video_url}, cannot extract video ID.")
    video_id = video_id_match.group(1)
    print("[LOG] Extracted video ID:", video_id)
    base_url = "https://youtube-transcriptor.p.rapidapi.com/transcript"
    params = {"video_id": video_id, "lang": "en"}
    headers = {
        "x-rapidapi-host": "youtube-transcriptor.p.rapidapi.com",
        "x-rapidapi-key": os.environ.get("RAPIDAPI_KEY")
    }
    try:
        response = requests.get(base_url, headers=headers, params=params, timeout=30)
        print("[LOG] RapidAPI Response Status Code:", response.status_code)
        body = response.text
        # The body contains the whole transcript; log only its beginning.
        print("[LOG] RapidAPI Response Body:", (body[:500] + "...") if len(body) > 500 else body)
        if response.status_code != 200:
            raise ValueError(f"RapidAPI transcription error: {response.status_code}, {response.text}")
        data = response.json()
        if not isinstance(data, list) or not data:
            raise ValueError(f"Unexpected transcript format or empty transcript: {data}")
        # `or ''` also covers an explicit null in the JSON.
        transcript_as_text = (data[0].get('transcriptionAsText') or '').strip()
        if not transcript_as_text:
            raise ValueError("transcriptionAsText field is missing or empty.")
        print("[LOG] Transcript retrieval successful.")
        print(f"[DEBUG] Transcript Length: {len(transcript_as_text)} characters.")
        snippet = transcript_as_text[:200] + "..." if len(transcript_as_text) > 200 else transcript_as_text
        print(f"[DEBUG] Transcript Snippet: {snippet}")
        return transcript_as_text
    except Exception as e:
        print("[ERROR] RapidAPI transcription error:", e)
        raise ValueError(f"Error transcribing YouTube video via RapidAPI: {str(e)}") from e
def generate_audio_mp3(text: str, speaker: str) -> str:
    """Synthesize *text* for *speaker* and return the path of a normalized temp MP3 file.

    The backend depends on ``st.session_state["language_selection"]`` (default
    "English (American)"): that value uses Deepgram; every other language uses Murf.
    The returned file is created with ``delete=False`` and is not removed here.

    Raises:
        ValueError: wrapping any failure from the TTS services or audio processing.
    """
    try:
        import streamlit as st
        print(f"[LOG] Generating audio for speaker: {speaker}")
        language_selection = st.session_state.get("language_selection", "English (American)")
        if language_selection == "English (American)":
            print("[LOG] Using Deepgram for English (American)")
            return _tts_with_deepgram(text, speaker)
        print(f"[LOG] Using Murf API for language: {language_selection}")
        return _tts_with_murf(text, speaker, language_selection)
    except Exception as e:
        print("[ERROR] Error generating audio:", e)
        raise ValueError(f"Error generating audio: {str(e)}") from e


def _new_temp_path(suffix: str) -> str:
    """Create an empty persistent temp file and return its path, with no handle left open."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
        return handle.name


def _write_temp_audio(chunks, suffix: str) -> str:
    """Write an iterable of byte chunks to a new temp file; remove the file if writing fails."""
    path = _new_temp_path(suffix)
    try:
        with open(path, "wb") as handle:
            for chunk in chunks:
                if chunk:
                    handle.write(chunk)
    except Exception:
        os.remove(path)
        raise
    return path


def _normalize_to_mp3(source_path: str, source_format: str) -> str:
    """Normalize *source_path* with pydub, export it to a new temp MP3, and always delete the source."""
    try:
        segment = effects.normalize(AudioSegment.from_file(source_path, format=source_format))
        final_mp3_path = _new_temp_path(".mp3")
        try:
            # export() returns the open output file; close it so the handle is not leaked.
            segment.export(final_mp3_path, format="mp3").close()
        except Exception:
            os.remove(final_mp3_path)
            raise
        return final_mp3_path
    finally:
        if os.path.exists(source_path):
            os.remove(source_path)


def _tts_with_deepgram(text: str, speaker: str, timeout: float = 120.0) -> str:
    """Synthesize *text* with Deepgram's /v1/speak endpoint and return a normalized MP3 path."""
    # Only the "Jane"/"John" voice roles skip the TTS text clean-up pass.
    if speaker in ["John", "Jane"]:
        processed_text = text
    else:
        processed_text = _preprocess_text_for_tts(text, speaker)
    params = {"model": "aura-zeus-en" if speaker == "John" else "aura-asteria-en"}
    headers = {
        "Accept": "audio/mpeg",
        "Content-Type": "application/json",
        "Authorization": f"Token {os.environ.get('DEEPGRAM_API_KEY')}"
    }
    # `with` closes the streamed connection even on the error paths below.
    with requests.post(
        "https://api.deepgram.com/v1/speak",
        params=params,
        headers=headers,
        json={"text": processed_text},
        stream=True,
        timeout=timeout,
    ) as response:
        if response.status_code != 200:
            raise ValueError(f"Deepgram TTS error: {response.status_code}, {response.text}")
        if 'audio/mpeg' not in response.headers.get('Content-Type', ''):
            raise ValueError("Unexpected Content-Type from Deepgram.")
        mp3_path = _write_temp_audio(response.iter_content(chunk_size=8192), ".mp3")
    return _normalize_to_mp3(mp3_path, "mp3")


def _tts_with_murf(text: str, speaker: str, language_selection: str, timeout: float = 120.0) -> str:
    """Synthesize *text* with Murf's /v1/speech/generate endpoint and return a normalized MP3 path."""
    if language_selection == "Hinglish":
        # Imported lazily: only this branch needs the transliteration package.
        from indic_transliteration.sanscript import transliterate, DEVANAGARI, IAST
        text = transliterate(text, DEVANAGARI, IAST)
    hindi_voice = language_selection in ["Hinglish", "Hindi"]
    multi_native_locale = "hi-IN" if hindi_voice else "en-IN"
    if hindi_voice:
        voice_id = "hi-IN-kabir" if speaker == "John" else "hi-IN-shweta"
    else:
        voice_id = "en-IN-aarav" if speaker == "John" else "en-IN-isha"
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "api-key": os.environ.get("MURF_API_KEY")
    }
    payload = {
        "audioDuration": 0,
        "channelType": "MONO",
        "encodeAsBase64": False,
        "format": "WAV",
        "modelVersion": "GEN2",
        "multiNativeLocale": multi_native_locale,
        "pitch": 0,
        "pronunciationDictionary": {},
        "rate": 0,
        "sampleRate": 48000,
        "style": "Conversational",
        "text": text,
        "variation": 1,
        "voiceId": voice_id
    }
    response = requests.post("https://api.murf.ai/v1/speech/generate", headers=headers, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"Murf API error: {response.status_code}, {response.text}")
    audio_url = response.json().get("audioFile")
    if not audio_url:
        raise ValueError("No audio file URL returned by Murf API")
    audio_response = requests.get(audio_url, timeout=timeout)
    if audio_response.status_code != 200:
        raise ValueError(f"Error fetching audio from {audio_url}")
    wav_path = _write_temp_audio([audio_response.content], ".wav")
    return _normalize_to_mp3(wav_path, "wav")
def transcribe_youtube_video_OLD_YTDLP(video_url: str) -> str:
    """Retired transcription entry point, kept only as a no-op stub.

    The name suggests a former yt-dlp based implementation (not present here).
    It ignores *video_url* and always returns None despite the ``str``
    annotation; ``transcribe_youtube_video`` is the working path.
    """
    return None
def _preprocess_text_for_tts(text: str, speaker: str) -> str: | |
text = re.sub(r"\bNo\.\b", "Number", text) | |
text = re.sub(r"\b(?i)SaaS\b", "sass", text) | |
abbreviations_as_words = {"NASA", "NATO", "UNESCO"} | |
def insert_periods_for_abbrev(m): | |
abbr = m.group(0) | |
if abbr in abbreviations_as_words: | |
return abbr | |
return ".".join(list(abbr)) + "." | |
text = re.sub(r"\b([A-Z]{2,})\b", insert_periods_for_abbrev, text) | |
text = re.sub(r"\.\.", ".", text) | |
def remove_periods_for_tts(m): | |
return m.group().replace(".", " ").strip() | |
text = re.sub(r"[A-Z]\.[A-Z](?:\.[A-Z])*\.", remove_periods_for_tts, text) | |
text = re.sub(r"-", " ", text) | |
text = re.sub(r"\b(ha(ha)?|heh|lol)\b", "(* laughs *)", text, flags=re.IGNORECASE) | |
text = re.sub(r"\bsigh\b", "(* sighs *)", text, flags=re.IGNORECASE) | |
text = re.sub(r"\b(groan|moan)\b", "(* groans *)", text, flags=re.IGNORECASE) | |
if speaker != "Jane": | |
def insert_thinking_pause(m): | |
word = m.group(1) | |
if random.random() < 0.3: | |
filler = random.choice(['hmm,', 'well,', 'let me see,']) | |
return f"{word}..., {filler}" | |
else: | |
return f"{word}...," | |
keywords_pattern = r"\b(important|significant|crucial|point|topic)\b" | |
text = re.sub(keywords_pattern, insert_thinking_pause, text, flags=re.IGNORECASE) | |
conj_pattern = r"\b(and|but|so|because|however)\b" | |
text = re.sub(conj_pattern, lambda m: f"{m.group()}...", text, flags=re.IGNORECASE) | |
text = re.sub(r"\b(uh|um|ah)\b", "", text, flags=re.IGNORECASE) | |
def capitalize_match(m): | |
return m.group().upper() | |
text = re.sub(r'(^\s*\w)|([.!?]\s*\w)', capitalize_match, text) | |
return text.strip() | |
def _spell_digits(d: str) -> str: | |
digit_map = { | |
'0': 'zero', '1': 'one', '2': 'two', '3': 'three', | |
'4': 'four', '5': 'five', '6': 'six', '7': 'seven', | |
'8': 'eight', '9': 'nine' | |
} | |
return " ".join(digit_map[ch] for ch in d if ch in digit_map) | |
def mix_with_bg_music(
    spoken: AudioSegment,
    custom_music_path=None,
    lead_in_ms: int = 2000,
    music_gain_db: float = -18.0,
) -> AudioSegment:
    """Overlay *spoken* onto looped background music.

    The music is ``custom_music_path`` if given, else ``bg_music.mp3``; it is
    always decoded as MP3. The music is adjusted by ``music_gain_db`` dB and
    looped to cover ``lead_in_ms`` plus the speech. The speech starts
    ``lead_in_ms`` milliseconds into the result.

    Returns:
        The mixed segment, or *spoken* unchanged if the music cannot be loaded
        or has zero length.
    """
    music_path = custom_music_path or "bg_music.mp3"
    try:
        bg_music = AudioSegment.from_file(music_path, format="mp3")
    except Exception as e:
        print("[ERROR] Failed to load background music:", e)
        return spoken
    # A zero-length track would never grow `looped_music` below: infinite loop.
    if len(bg_music) == 0:
        print("[ERROR] Background music is empty; returning speech without music.")
        return spoken
    bg_music = bg_music.apply_gain(music_gain_db)
    total_length_ms = len(spoken) + lead_in_ms
    looped_music = AudioSegment.empty()
    while len(looped_music) < total_length_ms:
        looped_music += bg_music
    looped_music = looped_music[:total_length_ms]
    return looped_music.overlay(spoken, position=lead_in_ms)
def call_groq_api_for_qa(system_prompt: str, timeout: float = 60.0) -> str:
    """Send *system_prompt* as one user message to Groq chat completions and return the reply.

    Uses the "deepseek-r1-distill-llama-70b" model. Any ``<think>...</think>``
    section in the reply is removed. Errors are not raised: any failure returns a
    JSON fallback line spoken by "John".

    Args:
        system_prompt: Full prompt text.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The stripped reply text, or ``json.dumps({"speaker": "John", "text": ...})`` on failure.
    """
    try:
        headers = {
            "Authorization": f"Bearer {os.environ.get('GROQ_API_KEY')}",
            "Content-Type": "application/json",
            "Accept": "application/json"
        }
        payload = {
            "model": "deepseek-r1-distill-llama-70b",
            "messages": [{"role": "user", "content": system_prompt}],
            "max_tokens": 512,
            "temperature": 0.7
        }
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
        # R1-distill models can put their reasoning in <think> tags inside the content.
        return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
    except Exception as e:
        print("[ERROR] Groq API error:", e)
        fallback = {"speaker": "John", "text": "I'm sorry, I'm having trouble answering right now."}
        return json.dumps(fallback)
# --- Agent and Tavily Integration --- | |
def run_research_agent(topic: str, report_type: str = "research_report", max_results: int = 10) -> str:
    """Research *topic* on the web and return a structured report.

    Pipeline:
    1. Search with Tavily.
    2. Scrape each result with Firecrawl.
    3. Write the report with a Groq-hosted LLM.
    4. Pass the report text through ``generate_report``.

    Args:
        topic: Search query and report subject.
        report_type: Currently unused; kept for API compatibility.
        max_results: Maximum number of Tavily results to scrape.

    Returns:
        Whatever ``generate_report`` returns for the LLM's report. If nothing can
        be retrieved or an error occurs, a human-readable message string instead.
    """
    print(f"[LOG] Starting research agent for topic: {topic}")
    try:
        tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
        search_response = tavily_client.search(query=topic, max_results=max_results)
        # tavily-python returns a dict ({"results": [{"url": ...}, ...]}), not an object;
        # _tavily_field also accepts attribute-style objects.
        search_results = _tavily_field(search_response, "results") or []
        if not search_results:
            return "No relevant search results found."
        print(f"[DEBUG] Tavily results: {search_results}")

        # Tag each page with its URL so the model can actually cite its sources.
        sections = []
        for result in search_results:
            url = _tavily_field(result, "url")
            if not url:
                continue
            print(f"[LOG] Scraping URL with Firecrawl: {url}")
            markdown = _scrape_with_firecrawl(url)
            if markdown:
                sections.append(f"Source: {url}\n{markdown}")
        if not sections:
            return "Could not retrieve content from any of the search results."
        # NOTE(review): the combined content is not size-limited; many long pages may
        # exceed the model's context/request limits (the error is caught below).
        combined_content = "\n\n".join(sections)

        prompt = (
            "You are a world-class researcher, and you are tasked to write a comprehensive research report on the following topic:\n"
            f"{topic}\n"
            "Use the following pieces of information, gathered from various web sources, to construct your report:\n"
            f"{combined_content}\n"
            "Compile and synthesize the information to create a well-structured and informative research report. "
            "Include a title, introduction, main body with clearly defined sections, and a conclusion. "
            "Cite sources appropriately in the context. Do not hallucinate or make anything up.\n"
        )
        groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        response = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="deepseek-r1-distill-llama-70b",
            temperature=0.2,
        )
        report_text = response.choices[0].message.content or ""
        # Keep only the answer, not any <think>...</think> reasoning the model emits.
        report_text = re.sub(r"<think>.*?</think>", "", report_text, flags=re.DOTALL).strip()
        return generate_report(report_text)
    except Exception as e:
        print(f"[ERROR] Error in research agent: {e}")
        return f"Sorry, I encountered an error during research: {e}"


def _tavily_field(item, name):
    """Read *name* from a Tavily response item that may be a dict or an attribute-style object."""
    if isinstance(item, dict):
        return item.get(name)
    return getattr(item, name, None)


def _scrape_with_firecrawl(url: str, timeout: float = 60.0) -> Optional[str]:
    """Scrape *url* via Firecrawl's /v1/scrape; return the page markdown, or None on any failure."""
    headers = {'Authorization': f'Bearer {os.environ.get("FIRECRAWL_API_KEY")}'}
    payload = {"url": url, "formats": ["markdown"], "onlyMainContent": True}
    try:
        response = requests.post("https://api.firecrawl.dev/v1/scrape", headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
        data = response.json()
    except (requests.RequestException, ValueError) as e:  # ValueError: body is not JSON
        print(f"[ERROR] Error during Firecrawl request for {url}: {e}")
        return None
    if not isinstance(data, dict):
        print(f"[WARNING] Firecrawl scrape failed or no markdown content for {url}: unexpected response")
        return None
    page = data.get('data') or {}
    if data.get('success') and isinstance(page, dict) and page.get('markdown'):
        return page['markdown']
    print(f"[WARNING] Firecrawl scrape failed or no markdown content for {url}: {data.get('error')}")
    return None