import os
import re
from datetime import datetime
from typing import Annotated

from dotenv import load_dotenv
from pydantic import BaseModel, Field

from langchain_core.messages import SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI

from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph_supervisor.supervisor import create_supervisor

from youtube_transcript_api import (
    NoTranscriptFound,
    TranscriptsDisabled,
    VideoUnavailable,
    YouTubeTranscriptApi,
)

from prompts import WEB_SEARCH_PROMPT, YOUTUBE_PROMPT, MULTIMODAL_PROMPT

# Load environment variables from .env file
load_dotenv()

# Initialize OpenAI LLM (gpt-4o); the Responses API is needed for the
# built-in web_search_preview tool used by the web search agent
openai_llm = ChatOpenAI(
    model="gpt-4o",
    use_responses_api=True,
    api_key=os.getenv("OPENAI_API_KEY")
)

# Initialize Google Gemini LLM for YouTube and multimodal tasks
google_llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash-preview-04-17",
    google_api_key=os.getenv("GOOGLE_API_KEY"),
)

class AgentState(MessagesState):
    """
    State class for agent workflows; tracks the message history.

    Note: MessagesState already defines `messages` with the add_messages
    reducer; it is redeclared here for explicitness.
    """
    messages: Annotated[list, add_messages]

class YouTubeTranscriptInput(BaseModel):
    """
    Input schema for the YouTube transcript tool.
    """
    video_url: str = Field(description="YouTube URL or video ID.")
    raw: bool = Field(default=False, description="Include timestamps?")

@tool("youtube_transcript", args_schema=YouTubeTranscriptInput)
def youtube_transcript(video_url: str, raw: bool = False) -> str:
    """
    Fetches the transcript of a YouTube video given its URL or ID.
    Returns plain text (no timestamps) or raw with timestamps.
    """
    # Extract video ID from URL or use as-is if already an ID
    if "youtube.com" in video_url or "youtu.be" in video_url:
        match = re.search(r"(?:v=|youtu\.be/)([\w-]{11})", video_url)
        if not match:
            return "Invalid YouTube URL or ID."
        video_id = match.group(1)
    else:
        video_id = video_url.strip()
    try:
        # Fetch the transcript (classic get_transcript interface, which
        # returns a list of {text, start, duration} dicts)
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        if raw:
            # Return transcript with timestamps
            return "\n".join(f"{int(e['start'])}s: {e['text']}" for e in transcript)
        # Return plain transcript text
        return " ".join(e['text'] for e in transcript)
    except TranscriptsDisabled:
        return "Transcripts are disabled for this video."
    except NoTranscriptFound:
        return "No transcript found for this video."
    except VideoUnavailable:
        return "This video is unavailable."
    except Exception as e:
        return f"An error occurred while fetching the transcript: {e}"

# List of available tools for the agent (currently only YouTube transcript)
tools = [youtube_transcript]
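
# Example (a sketch, not executed at import time): tools created with @tool
# are runnables, so the transcript tool can be smoke-tested directly.
# The video ID below is a placeholder.
#   youtube_transcript.invoke({"video_url": "VIDEO_ID_HERE"})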

def create_web_search_graph() -> StateGraph:
    """
    Create the web search agent graph.

    Returns:
        StateGraph: The compiled web search agent workflow.
    """
    web_search_preview = [{"type": "web_search_preview"}]
    # Bind the web search tool to the OpenAI LLM
    llm_with_tools = openai_llm.bind_tools(web_search_preview)

    def agent_node(state: AgentState) -> dict:
        """
        Node function for handling web search queries.

        Args:
            state (AgentState): The current agent state.

        Returns:
            dict: Updated state with the LLM response.
        """
        current_date = datetime.now().strftime("%B %d, %Y")
        # Format the system prompt with the current date
        system_message = SystemMessage(content=WEB_SEARCH_PROMPT.format(current_date=current_date))
        response = llm_with_tools.invoke([system_message] + state["messages"])
        # The add_messages reducer appends the response to the history
        return {"messages": [response]}

    # Build the workflow graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_edge(START, "agent")
    workflow.add_edge("agent", END)
    return workflow.compile(name="web_search_agent")

def create_youtube_viewer_graph() -> StateGraph:
    """
    Create the YouTube viewer agent graph.

    Returns:
        StateGraph: The compiled YouTube viewer agent workflow.
    """
    def agent_node(state: AgentState) -> dict:
        """
        Node function for handling YouTube-related queries.

        Args:
            state (AgentState): The current agent state.

        Returns:
            dict: Updated state with the LLM response.
        """
        current_date = datetime.now().strftime("%B %d, %Y")
        # Format the system prompt with the current date
        system_message = SystemMessage(content=YOUTUBE_PROMPT.format(current_date=current_date))
        # Bind the YouTube transcript tool to the Gemini LLM
        llm_with_tools = google_llm.bind_tools(tools)
        response = llm_with_tools.invoke([system_message] + state["messages"])
        # The add_messages reducer appends the response to the history
        return {"messages": [response]}

    # Build the workflow graph with tool node and conditional routing
    workflow = StateGraph(AgentState)
    workflow.add_node("llm", agent_node)
    workflow.add_node("tools", ToolNode(tools))
    workflow.set_entry_point("llm")
    workflow.add_conditional_edges(
        "llm",
        tools_condition,
        {
            "tools": "tools",   # If tool is needed, go to tools node
            "__end__": END,     # Otherwise, end the workflow
        },
    )
    workflow.add_edge("tools", "llm")  # After tool, return to LLM node
    return workflow.compile(name="youtube_viewer_agent")

def create_multimodal_agent_graph() -> StateGraph:
    """
    Create the multimodal agent graph using Gemini for best multimodal support.

    Returns:
        StateGraph: The compiled multimodal agent workflow.
    """
    def agent_node(state: AgentState) -> dict:
        """
        Node function for handling multimodal queries.

        Args:
            state (AgentState): The current agent state.

        Returns:
            dict: Updated state with the LLM response.
        """
        current_date = datetime.now().strftime("%B %d, %Y")
        # Compose the system message with the multimodal prompt and current date
        system_message = SystemMessage(content=MULTIMODAL_PROMPT + f" Today's date: {current_date}.")
        messages = [system_message] + state["messages"]
        # Invoke Gemini LLM for multimodal reasoning
        response = google_llm.invoke(messages)
        # The add_messages reducer appends the response to the history
        return {"messages": [response]}

    # Build the workflow graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_edge(START, "agent")
    workflow.add_edge("agent", END)
    return workflow.compile(name="multimodal_agent")

# Instantiate the agent graphs
multimodal_agent = create_multimodal_agent_graph()
web_search_agent = create_web_search_graph()
youtube_agent = create_youtube_viewer_graph()

# Create the supervisor workflow to route queries to the appropriate sub-agent
supervisor_workflow = create_supervisor(
    [web_search_agent, youtube_agent, multimodal_agent],
    model=openai_llm,
    prompt=(
        "You are a supervisor. For each question, call one of your sub-agents and return their answer directly to the user. Do not modify, summarize, or rephrase the answer."
    )
)

# Compile the supervisor agent for use in the application
supervisor_agent = supervisor_workflow.compile(name="supervisor_agent")
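
# Minimal smoke test (a sketch, kept out of import time via the main guard):
# compiled LangGraph graphs expose `.invoke`, and the final answer is the
# last message in the returned state. The question below is illustrative.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    result = supervisor_agent.invoke(
        {"messages": [HumanMessage(content="What is LangGraph?")]}
    )
    print(result["messages"][-1].content)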