siddhartharyaai committed (verified) · Commit c0e2ca4 · Parent(s): aa1b4ab

Create utils.py

Files changed (1): utils.py (+468, -0)

utils.py ADDED
# utils.py

import os
import re
import json
import requests
import tempfile
from bs4 import BeautifulSoup
from typing import List, Literal
from pydantic import BaseModel
from pydub import AudioSegment, effects
from transformers import pipeline
import yt_dlp
import tiktoken
from groq import Groq  # Groq client for LLM calls
import numpy as np
import torch  # Used to check CUDA availability


class DialogueItem(BaseModel):
    speaker: Literal["Jane", "John"]
    text: str


class Dialogue(BaseModel):
    dialogue: List[DialogueItem]


# Initialize the Whisper ASR pipeline (GPU if available, otherwise CPU).
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    device=0 if torch.cuda.is_available() else -1,
)


def truncate_text(text, max_tokens=2048):
    print("[LOG] Truncating text if needed.")
    tokenizer = tiktoken.get_encoding("cl100k_base")
    tokens = tokenizer.encode(text)
    if len(tokens) > max_tokens:
        print("[LOG] Text too long, truncating.")
        return tokenizer.decode(tokens[:max_tokens])
    return text

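# Usage sketch for truncate_text (hypothetical input; cl100k_base counts
# tokens, not words, so the cutoff will not fall on a word boundary):
#   short = truncate_text("word " * 5000, max_tokens=100)
#   # `short` decodes back to roughly the first 100 tokens of the input.
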
def extract_text_from_url(url):
    print("[LOG] Extracting text from URL:", url)
    try:
        # A timeout keeps a dead host from hanging the whole pipeline.
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            print(f"[ERROR] Failed to fetch URL: {url} with status code {response.status_code}")
            return ""
        soup = BeautifulSoup(response.text, 'html.parser')
        for script in soup(["script", "style"]):
            script.decompose()
        text = soup.get_text(separator=' ')
        print("[LOG] Text extraction from URL successful.")
        return text
    except Exception as e:
        print(f"[ERROR] Exception during text extraction from URL: {e}")
        return ""


def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
    """
    Shifts the pitch of an AudioSegment by a given number of semitones.
    Positive semitones shift the pitch up, negative shift it down.
    """
    print(f"[LOG] Shifting pitch by {semitones} semitones.")
    # Resampling by 2^(semitones/12) shifts pitch; note this simple approach
    # changes tempo along with pitch. Restoring the original frame rate keeps
    # the segment compatible with the rest of the audio chain.
    new_sample_rate = int(audio.frame_rate * (2.0 ** (semitones / 12.0)))
    shifted_audio = audio._spawn(audio.raw_data, overrides={'frame_rate': new_sample_rate})
    return shifted_audio.set_frame_rate(audio.frame_rate)

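# Usage sketch (hypothetical file name): deepen a voice by 5 semitones.
#   seg = AudioSegment.from_file("voice.mp3", format="mp3")
#   deeper = pitch_shift(seg, -5)   # negative = down; also slows tempo slightly
#   deeper.export("voice_deep.mp3", format="mp3")
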
def is_sufficient(text: str, min_word_count: int = 500) -> bool:
    """
    Determines if the fetched information meets the sufficiency criteria.

    :param text: Aggregated text from primary sources.
    :param min_word_count: Minimum number of words required.
    :return: True if sufficient, False otherwise.
    """
    word_count = len(text.split())
    print(f"[DEBUG] Aggregated word count: {word_count}")
    return word_count >= min_word_count


def query_llm_for_additional_info(topic: str, existing_text: str) -> str:
    """
    Queries the Groq API to retrieve additional relevant information from the LLM's knowledge base.

    :param topic: The research topic.
    :param existing_text: The text already gathered from primary sources.
    :return: Additional relevant information as a string.
    """
    print("[LOG] Querying LLM for additional information.")
    # Define the system prompt for the LLM.
    system_prompt = (
        "You are an AI assistant with extensive knowledge up to 2023-10. "
        "Provide additional relevant information on the following topic based on your knowledge base.\n\n"
        f"Topic: {topic}\n\n"
        f"Existing Information: {existing_text}\n\n"
        "Please add more insightful details, facts, and perspectives to enhance the understanding of the topic."
    )

    groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt}],
            model="llama-3.3-70b-versatile",
            max_tokens=1024,
            temperature=0.7,
        )
    except Exception as e:
        print("[ERROR] Groq API error during fallback:", e)
        return ""

    additional_info = response.choices[0].message.content.strip()
    print("[DEBUG] Additional information from LLM:")
    print(additional_info)
    return additional_info

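# Usage sketch (assumes GROQ_API_KEY is set in the environment; the function
# returns "" on any API failure rather than raising):
#   extra = query_llm_for_additional_info("Mars rovers", aggregated_text)
#   if extra:
#       aggregated_text += " " + extra
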
def research_topic(topic: str) -> str:
    # Sources: a set of news RSS feeds, plus a topic-specific Google News query.
    sources = {
        "BBC": "https://feeds.bbci.co.uk/news/rss.xml",
        "CNN": "http://rss.cnn.com/rss/edition.rss",
        "Associated Press": "https://apnews.com/apf-topnews",
        "NDTV": "https://www.ndtv.com/rss/top-stories",
        "Times of India": "https://timesofindia.indiatimes.com/rssfeeds/296589292.cms",
        "The Hindu": "https://www.thehindu.com/news/national/kerala/rssfeed.xml",
        "Economic Times": "https://economictimes.indiatimes.com/rssfeeds/1977021501.cms",
        "Google News - Custom": f"https://news.google.com/rss/search?q={requests.utils.quote(topic)}&hl=en-IN&gl=IN&ceid=IN:en",
    }

    summary_parts = []

    # Wikipedia summary
    wiki_summary = fetch_wikipedia_summary(topic)
    if wiki_summary:
        summary_parts.append(f"From Wikipedia: {wiki_summary}")

    # For each news RSS feed
    for name, url in sources.items():
        try:
            items = fetch_rss_feed(url)
            if not items:
                continue
            # Use simple keyword matching.
            title, desc, link = find_relevant_article(items, topic, min_match=2)
            if link:
                article_text = fetch_article_text(link)
                if article_text:
                    summary_parts.append(f"From {name}: {article_text}")
                else:
                    # If no main text was extracted, use title/description.
                    summary_parts.append(f"From {name}: {title} - {desc}")
        except Exception as e:
            print(f"[ERROR] Error fetching from {name} RSS feed:", e)
            continue

    aggregated_info = " ".join(summary_parts)
    print("[DEBUG] Aggregated information from primary sources.")
    print(aggregated_info)

    if not is_sufficient(aggregated_info):
        print("[LOG] Insufficient information from primary sources. Initiating fallback to LLM.")
        additional_info = query_llm_for_additional_info(topic, aggregated_info)
        if additional_info:
            aggregated_info += " " + additional_info
        else:
            print("[ERROR] Failed to retrieve additional information from LLM.")

    if not aggregated_info:
        # No info found at all.
        print("[LOG] No information found for the topic.")
        return f"Sorry, I couldn't find recent information on '{topic}'."

    return aggregated_info

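# Usage sketch: research_topic is the entry point the rest of the app (not
# shown here) would call; SYSTEM_PROMPT is a hypothetical caller-defined string:
#   background = research_topic("semiconductor exports")
#   script = generate_script(SYSTEM_PROMPT, background, "Casual", "3-5 Mins")
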
def fetch_wikipedia_summary(topic: str) -> str:
    print("[LOG] Fetching Wikipedia summary for:", topic)
    try:
        # 1. Search for the topic.
        search_url = (
            f"https://en.wikipedia.org/w/api.php?action=opensearch"
            f"&search={requests.utils.quote(topic)}&limit=1&namespace=0&format=json"
        )
        resp = requests.get(search_url, timeout=10)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch Wikipedia search results for topic: {topic}")
            return ""
        data = resp.json()
        if len(data) > 1 and data[1]:
            title = data[1][0]
            # 2. Fetch the page summary.
            summary_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{requests.utils.quote(title)}"
            s_resp = requests.get(summary_url, timeout=10)
            if s_resp.status_code == 200:
                s_data = s_resp.json()
                if "extract" in s_data:
                    print("[LOG] Wikipedia summary fetched successfully.")
                    return s_data["extract"]
        print("[LOG] No Wikipedia summary found for topic:", topic)
        return ""
    except Exception as e:
        print(f"[ERROR] Exception during Wikipedia summary fetch: {e}")
        return ""


def fetch_rss_feed(feed_url: str) -> list:
    print("[LOG] Fetching RSS feed:", feed_url)
    try:
        resp = requests.get(feed_url, timeout=10)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch RSS feed: {feed_url} with status code {resp.status_code}")
            return []
        # Use html.parser instead of xml to avoid needing lxml or another parser.
        soup = BeautifulSoup(resp.content, "html.parser")
        items = soup.find_all("item")
        print(f"[LOG] Number of items fetched from {feed_url}: {len(items)}")
        return items
    except Exception as e:
        print(f"[ERROR] Exception occurred while fetching RSS feed {feed_url}: {e}")
        return []


def find_relevant_article(items, topic: str, min_match=2) -> tuple:
    """
    Searches for relevant articles based on topic keywords.

    :param items: List of RSS feed items
    :param topic: Topic string
    :param min_match: Minimum number of keyword matches required
    :return: (title, description, link) or (None, None, None)
    """
    print("[LOG] Finding relevant articles...")
    keywords = re.findall(r'\w+', topic.lower())
    print(f"[LOG] Topic keywords: {keywords}")

    for item in items:
        title = item.find("title").get_text().strip() if item.find("title") else ""
        description = item.find("description").get_text().strip() if item.find("description") else ""
        text = f"{title.lower()} {description.lower()}"
        matches = sum(1 for kw in keywords if kw in text)
        print(f"[DEBUG] Checking article: '{title}' | Matches: {matches}/{len(keywords)}")
        if matches >= min_match:
            link_tag = item.find("link")
            link = link_tag.get_text().strip() if link_tag else ""
            if not link and link_tag is not None and link_tag.next_sibling:
                # html.parser treats <link> as a void tag, so the URL often ends
                # up as the tag's next sibling rather than its text content.
                link = str(link_tag.next_sibling).strip()
            print(f"[LOG] Relevant article found: {title}")
            return title, description, link
    print("[LOG] No relevant articles found based on the current matching criteria.")
    return None, None, None

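# Matching sketch: for topic "mars rover landing", keywords become
# ['mars', 'rover', 'landing']; an item titled "NASA rover nears Mars"
# matches 2 of 3 keywords and passes the default min_match=2 threshold.
# Note the check is substring-based, so 'mars' would also match "marshal".
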
def fetch_article_text(link: str) -> str:
    print("[LOG] Fetching article text from:", link)
    if not link:
        print("[LOG] No link provided for fetching article text.")
        return ""
    try:
        resp = requests.get(link, timeout=10)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch article from link: {link} with status code {resp.status_code}")
            return ""
        soup = BeautifulSoup(resp.text, 'html.parser')
        # Article markup is site-specific, so take a generic approach:
        # just collect the paragraphs.
        paragraphs = soup.find_all("p")
        text = " ".join(p.get_text() for p in paragraphs[:5])  # first 5 paragraphs for context
        print("[LOG] Article text fetched successfully.")
        return text.strip()
    except Exception as e:
        print(f"[ERROR] Error fetching article text: {e}")
        return ""


def generate_script(system_prompt: str, input_text: str, tone: str, target_length: str):
    print("[LOG] Generating script with tone:", tone, "and length:", target_length)
    groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

    # Map target_length to word ranges.
    length_mapping = {
        "1-3 Mins": (200, 450),
        "3-5 Mins": (450, 750),
        "5-10 Mins": (750, 1500),
        "10-20 Mins": (1500, 3000),
    }
    min_words, max_words = length_mapping.get(target_length, (200, 450))

    # Describe each tone concretely so the prompt is unambiguous.
    tone_description = {
        "Humorous": "funny and exciting, makes people chuckle",
        "Formal": "business-like, well-structured, professional",
        "Casual": "like a conversation between close friends, relaxed and informal",
        "Youthful": "like how teenagers might chat, energetic and lively",
    }
    chosen_tone = tone_description.get(tone, "casual")

    # Construct the prompt with clear instructions for JSON output.
    prompt = (
        f"{system_prompt}\n"
        f"TONE: {chosen_tone}\n"
        f"TARGET LENGTH: {target_length} ({min_words}-{max_words} words)\n"
        f"INPUT TEXT: {input_text}\n\n"
        "Please provide the output in the following JSON format without any additional text:\n\n"
        "{\n"
        '  "dialogue": [\n'
        '    {\n'
        '      "speaker": "Jane",\n'
        '      "text": "..."\n'
        '    },\n'
        '    {\n'
        '      "speaker": "John",\n'
        '      "text": "..."\n'
        '    }\n'
        "  ]\n"
        "}"
    )
    print("[LOG] Sending prompt to Groq:")
    print(prompt)  # Log the prompt being sent

    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "system", "content": prompt}],
            model="llama-3.3-70b-versatile",
            max_tokens=2048,
            temperature=0.7,
        )
    except Exception as e:
        print("[ERROR] Groq API error:", e)
        raise ValueError(f"Error communicating with Groq API: {str(e)}")

    # Log the raw response content for debugging.
    raw_content = response.choices[0].message.content.strip()
    print("[DEBUG] Raw API response content:")
    print(raw_content)

    # Strip any markdown code fences, then extract the JSON object.
    content = raw_content.replace('```json', '').replace('```', '').strip()

    start_index = content.find('{')
    end_index = content.rfind('}')

    if start_index == -1 or end_index == -1:
        print("[ERROR] Failed to parse dialogue. No JSON found.")
        print("[ERROR] Entire response content:")
        print(content)
        raise ValueError("Failed to parse dialogue: Could not find JSON object in response.")

    json_str = content[start_index:end_index + 1].strip()

    print("[DEBUG] Extracted JSON string:")
    print(json_str)

    try:
        data = json.loads(json_str)
        print("[LOG] Script generated successfully.")
        return Dialogue(**data)
    except json.JSONDecodeError as e:
        print("[ERROR] JSON decoding failed:", e)
        print("[ERROR] Response content causing failure:")
        print(content)
        raise ValueError(f"Failed to parse dialogue: {str(e)}")

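# Usage sketch: generate_script returns a validated Dialogue model, so a
# caller can iterate turns directly (SYSTEM_PROMPT is hypothetical):
#   script = generate_script(SYSTEM_PROMPT, article_text, "Formal", "1-3 Mins")
#   for turn in script.dialogue:
#       mp3_path = generate_audio_mp3(turn.text, turn.speaker)
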
def generate_audio_mp3(text: str, speaker: str) -> str:
    try:
        print(f"[LOG] Generating audio for speaker: {speaker}")

        # Deepgram TTS endpoint.
        deepgram_api_url = "https://api.deepgram.com/v1/speak"

        # Query parameters: choose a voice model per speaker. More parameters
        # (bit_rate, sample_rate, etc.) can be added here as needed.
        if speaker == "Jane":
            params = {"model": "aura-asteria-en"}  # Female voice
        elif speaker == "John":
            params = {"model": "aura-perseus-en"}  # Male voice
        else:
            raise ValueError(f"Unknown speaker: {speaker}")

        headers = {
            "Accept": "audio/mpeg",  # Request MP3 output
            "Content-Type": "application/json",
            "Authorization": f"Token {os.environ.get('DEEPGRAM_API_KEY')}"
        }

        body = {
            "text": text
        }

        print("[LOG] Sending TTS request to Deepgram...")
        # Make the POST request to Deepgram's TTS API.
        response = requests.post(deepgram_api_url, params=params, headers=headers, json=body, stream=True, timeout=60)

        if response.status_code != 200:
            print(f"[ERROR] Deepgram TTS API returned status code {response.status_code}: {response.text}")
            raise ValueError(f"Deepgram TTS API error: {response.status_code} - {response.text}")

        # Verify Content-Type.
        content_type = response.headers.get('Content-Type', '')
        if 'audio/mpeg' not in content_type:
            print("[ERROR] Unexpected Content-Type received from Deepgram:", content_type)
            print("[ERROR] Response content:", response.text)
            raise ValueError("Unexpected Content-Type received from Deepgram.")

        # Save the streamed audio to a temporary MP3 file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_file:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    mp3_file.write(chunk)
            mp3_temp_path = mp3_file.name
        print(f"[LOG] Audio received from Deepgram and saved at: {mp3_temp_path}")

        # Normalize audio volume.
        audio_seg = AudioSegment.from_file(mp3_temp_path, format="mp3")
        audio_seg = effects.normalize(audio_seg)

        # Pitch shifting for the male voice was removed in favor of a native
        # male voice model. Previously:
        #   if speaker == "John":
        #       audio_seg = pitch_shift(audio_seg, semitones=-5)

        # Export the final audio as MP3; reserve the path with a closed
        # handle so pydub can reopen it safely on all platforms.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as out_file:
            final_mp3_path = out_file.name
        audio_seg.export(final_mp3_path, format="mp3")
        print("[LOG] Audio post-processed and saved at:", final_mp3_path)

        # Clean up the initial MP3 file.
        if os.path.exists(mp3_temp_path):
            os.remove(mp3_temp_path)
            print(f"[LOG] Removed temporary MP3 file: {mp3_temp_path}")

        return final_mp3_path
    except Exception as e:
        print("[ERROR] Error generating audio:", e)
        raise ValueError(f"Error generating audio: {str(e)}")

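# Usage sketch (assumes DEEPGRAM_API_KEY is set in the environment): the
# returned path points at a normalized MP3 the caller is responsible for
# deleting once it has been consumed, e.g.:
#   path = generate_audio_mp3("Hello and welcome!", "Jane")
#   segment = AudioSegment.from_file(path, format="mp3")
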
def transcribe_youtube_video(video_url: str) -> str:
    print("[LOG] Transcribing YouTube video:", video_url)
    # Reserve a unique base path; yt-dlp's audio extractor will write
    # base_path + ".wav" once FFmpegExtractAudio has run. Keeping the
    # extension out of outtmpl avoids the download and the converted WAV
    # colliding on the same filename.
    fd, base_path = tempfile.mkstemp()
    os.close(fd)
    os.remove(base_path)
    audio_file = base_path + ".wav"

    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': base_path + '.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
            'preferredquality': '192'
        }],
        'quiet': True,
        'no_warnings': True,
    }

    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([video_url])
    except yt_dlp.utils.DownloadError as e:
        print("[ERROR] yt-dlp download error:", e)
        raise ValueError(f"Error downloading YouTube video: {str(e)}")

    print("[LOG] Audio downloaded at:", audio_file)
    try:
        # Run ASR on the downloaded audio; chunking lets Whisper handle
        # clips longer than its 30-second window.
        result = asr_pipeline(audio_file, chunk_length_s=30)
        transcript = result["text"]
        print("[LOG] Transcription completed.")
        return transcript.strip()
    except Exception as e:
        print("[ERROR] ASR transcription error:", e)
        raise ValueError(f"Error transcribing YouTube video: {str(e)}")
    finally:
        # Clean up the downloaded audio file.
        if os.path.exists(audio_file):
            os.remove(audio_file)
            print(f"[LOG] Removed temporary audio file: {audio_file}")