Update utils.py
utils.py
CHANGED
@@ -7,35 +7,17 @@ from bs4 import BeautifulSoup
 from typing import List, Literal, Optional
 from pydantic import BaseModel
 from pydub import AudioSegment, effects
-import
 import tiktoken
-from groq import Groq
 import numpy as np
-import torch
-from transformers import pipeline  # Moved to the top, since it's used before other things
 import random
-from tavily import TavilyClient  # Moved
-from report_structure import generate_report  # Import report structure
-
-# --- Add the cloned repository to the Python path ---
-repo_path = os.path.join('/home', 'user', 'open_deep_research')
-print(f"DEBUG: repo_path = {repo_path}")  # Debug print - keep this for now
-if repo_path not in sys.path:
-    print("DEBUG: Adding repo_path to sys.path")  # Debug print - keep this
-    sys.path.insert(0, repo_path)
-else:
-    print("DEBUG: repo_path already in sys.path")  # Debug print - keep this for now
-print(f"DEBUG: sys.path = {sys.path}")  # Debug print - keep this for now
-
-# --- CORRECT IMPORT (for local cloned repo) ---
-try:
-    from open_deep_research.agent import OpenDeepResearchAgent
-    print("DEBUG: Import successful!")
-except ImportError as e:
-    print(f"DEBUG: Import failed: {e}")
-    raise
-from report_structure import generate_report
 
 
 class DialogueItem(BaseModel):
     speaker: Literal["Jane", "John"]
@@ -60,28 +42,6 @@ def truncate_text(text, max_tokens=2048):
         return tokenizer.decode(tokens[:max_tokens])
     return text
 
-def extract_text_from_url(url):
-    # This function is retained for potential edge cases.
-    print("[LOG] Extracting text from URL (fallback method):", url)
-    try:
-        headers = {
-            "User-Agent": ("Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
-                           "AppleWebKit/537.36 (KHTML, like Gecko) "
-                           "Chrome/115.0.0.0 Safari/537.36")
-        }
-        response = requests.get(url, headers=headers)
-        if response.status_code != 200:
-            print(f"[ERROR] Failed to fetch URL: {url} with status code {response.status_code}")
-            return ""
-        soup = BeautifulSoup(response.text, 'html.parser')
-        for script in soup(["script", "style"]):
-            script.decompose()
-        text = soup.get_text(separator=' ')
-        print("[LOG] Text extraction from URL (fallback) successful.")
-        return text
-    except Exception as e:
-        print(f"[ERROR] Exception during text extraction from URL (fallback): {e}")
-        return ""
 
 def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
     print(f"[LOG] Shifting pitch by {semitones} semitones.")
@@ -89,34 +49,15 @@ def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
     shifted_audio = audio._spawn(audio.raw_data, overrides={'frame_rate': new_sample_rate})
     return shifted_audio.set_frame_rate(audio.frame_rate)
 
-
-
-
-
-
-
-def
-
-    pass
-def research_topic(topic: str) -> str:
-    # No longer needed
-    pass
-
-def fetch_wikipedia_summary(topic: str) -> str:
-    # No longer needed
-    pass
 
-def fetch_rss_feed(feed_url: str) -> list:
-    # No longer needed
-    pass
-
-def find_relevant_article(items, topic: str, min_match=2) -> tuple:
-    # No longer needed
-    pass
-
-def fetch_article_text(link: str) -> str:
-    # No longer needed
-    pass
 
 def generate_script(
     system_prompt: str,
@@ -129,7 +70,7 @@ def generate_script(
     sponsor_provided=None
 ):
     print("[LOG] Generating script with tone:", tone, "and length:", target_length)
-    import streamlit as st
     if (host_name == "Jane" or not host_name) and st.session_state.get("language_selection") in ["English (Indian)", "Hinglish", "Hindi"]:
         host_name = "Isha"
     if (guest_name == "John" or not guest_name) and st.session_state.get("language_selection") in ["English (Indian)", "Hinglish", "Hindi"]:
@@ -483,73 +424,73 @@ def generate_script(
     return json.dumps(fallback)
 
 # --- Agent and Tavily Integration ---
-
-
-  … (70 blank lines removed) …

@@ -7,35 +7,17 @@ from bs4 import BeautifulSoup
 from typing import List, Literal, Optional
 from pydantic import BaseModel
 from pydub import AudioSegment, effects
+from transformers import pipeline
 import tiktoken
+from groq import Groq  # Retained for LLM interaction
 import numpy as np
+import torch
 import random
 
+# --- CORRECT IMPORTS ---
+# No more sys.path modification!
+from report_structure import generate_report  # For report structuring
+from tavily import TavilyClient
 
 class DialogueItem(BaseModel):
     speaker: Literal["Jane", "John"]
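
A side note on the retained DialogueItem model above: the sketch below shows how the Literal speaker constraint behaves under pydantic validation. It is a minimal, self-contained sketch; the real model in utils.py may declare further fields beyond `speaker`, which are omitted here.

# Minimal sketch (assumption: only the `speaker` field shown in the hunk is modelled).
from typing import Literal
from pydantic import BaseModel, ValidationError

class DialogueItem(BaseModel):
    speaker: Literal["Jane", "John"]

print(DialogueItem(speaker="Jane"))    # accepted
try:
    DialogueItem(speaker="Alice")      # rejected: not one of the declared literals
except ValidationError as err:
    print(err)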

@@ -60,28 +42,6 @@ def truncate_text(text, max_tokens=2048):
         return tokenizer.decode(tokens[:max_tokens])
     return text
 
 
 def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
     print(f"[LOG] Shifting pitch by {semitones} semitones.")

@@ -89,34 +49,15 @@ def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
     shifted_audio = audio._spawn(audio.raw_data, overrides={'frame_rate': new_sample_rate})
     return shifted_audio.set_frame_rate(audio.frame_rate)
 
+# --- Functions no longer needed ---
+# def is_sufficient(...)
+# def query_llm_for_additional_info(...)
+# def research_topic(...)
+# def fetch_wikipedia_summary(...)
+# def fetch_rss_feed(...)
+# def find_relevant_article(...)
+# def fetch_article_text(...)
 
 
 def generate_script(
     system_prompt: str,

@@ -129,7 +70,7 @@ def generate_script(
     sponsor_provided=None
 ):
     print("[LOG] Generating script with tone:", tone, "and length:", target_length)
+    import streamlit as st  # Import streamlit here, where it's used
     if (host_name == "Jane" or not host_name) and st.session_state.get("language_selection") in ["English (Indian)", "Hinglish", "Hindi"]:
         host_name = "Isha"
     if (guest_name == "John" or not guest_name) and st.session_state.get("language_selection") in ["English (Indian)", "Hinglish", "Hindi"]:

@@ -483,73 +424,73 @@ def generate_script(
     return json.dumps(fallback)
 
 # --- Agent and Tavily Integration ---
+def run_research_agent(topic: str, report_type: str = "research_report", max_results: int = 10) -> str:
+    """
+    Runs the new research agent to generate a research report. This version uses
+    Tavily for search and Firecrawl for content extraction.
+    """
+    print(f"[LOG] Starting research agent for topic: {topic}")
+    try:
+        tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
+        search_results = tavily_client.search(query=topic, max_results=max_results).results
+
+        if not search_results:
+            return "No relevant search results found."
+
+        print(f"[DEBUG] Tavily results: {search_results}")
+
+        # Use Firecrawl to scrape the content of each URL
+        combined_content = ""
+        for result in search_results:
+            url = result.url  # Use dot notation to access attributes
+            print(f"[LOG] Scraping URL with Firecrawl: {url}")
+            headers = {'Authorization': f'Bearer {os.environ.get("FIRECRAWL_API_KEY")}'}
+            payload = {"url": url, "formats": ["markdown"], "onlyMainContent": True}
+            try:
+                response = requests.post("https://api.firecrawl.dev/v1/scrape", headers=headers, json=payload)
+                response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
+                data = response.json()
+                # print(f"[DEBUG] Firecrawl response: {data}")  # keep commented
+
+                if data.get('success') and 'markdown' in data.get('data', {}):
+                    combined_content += data['data']['markdown'] + "\n\n"
+                else:
+                    print(f"[WARNING] Firecrawl scrape failed or no markdown content for {url}: {data.get('error')}")
+
+            except requests.RequestException as e:
+                print(f"[ERROR] Error during Firecrawl request for {url}: {e}")
+                continue  # Continue to the next URL
+
+        if not combined_content:
+            return "Could not retrieve content from any of the search results."
+
+        # Use Groq LLM to generate the report
+        prompt = f"""You are a world-class researcher, and you are tasked to write a comprehensive research report on the following topic:
+
+{topic}
+
+Use the following pieces of information, gathered from various web sources, to construct your report:
+
+{combined_content}
+
+Compile and synthesize the information to create a well-structured and informative research report. Include a title, introduction, main body with clearly defined sections, and a conclusion. Cite sources appropriately in the context. Do not hallucinate or make anything up.
+"""
+
+        groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+        response = groq_client.chat.completions.create(
+            messages=[
+                {"role": "user", "content": prompt}
+            ],
+            model="deepseek-r1-distill-llama-70b",
+            temperature=0.2
+        )
+        report_text = response.choices[0].message.content
+        # print(f"[DEBUG] Raw report from LLM:\n{report_text}")  # Keep commented out unless you have a very specific reason
+
+        structured_report = generate_report(report_text)  # Use your report structuring function
+        return structured_report
+
+    except Exception as e:
+        print(f"[ERROR] Error in research agent: {e}")
+        return f"Sorry, I encountered an error during research: {e}"