import os
import re
from datetime import datetime
from typing import Annotated
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from langchain_core.messages import SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph_supervisor.supervisor import create_supervisor
from youtube_transcript_api import (
    NoTranscriptFound,
    TranscriptsDisabled,
    VideoUnavailable,
    YouTubeTranscriptApi,
)

from prompts import WEB_SEARCH_PROMPT, YOUTUBE_PROMPT, MULTIMODAL_PROMPT

# Load environment variables from .env file
load_dotenv()

# Initialize OpenAI LLM (gpt-4o) for general and web search tasks
openai_llm = ChatOpenAI(
    model="gpt-4o",
    use_responses_api=True,
    api_key=os.getenv("OPENAI_API_KEY"),
)

# Initialize Google Gemini LLM for YouTube and multimodal tasks
google_llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash-preview-04-17",
    google_api_key=os.getenv("GOOGLE_API_KEY"),
)
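
# NOTE (assumption): both OPENAI_API_KEY and GOOGLE_API_KEY are expected to be set
# in the environment (e.g. via the .env file loaded above); neither client will
# authenticate without its key.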


class AgentState(MessagesState):
    """
    State class for agent workflows; tracks the message history.
    """

    messages: Annotated[list, add_messages]


class YouTubeTranscriptInput(BaseModel):
    """
    Input schema for the YouTube transcript tool.
    """

    video_url: str = Field(description="YouTube URL or video ID.")
    raw: bool = Field(default=False, description="Include timestamps?")
@tool("youtube_transcript", args_schema=YouTubeTranscriptInput)
def youtube_transcript(video_url: str, raw: bool = False) -> str:
"""
Fetches the transcript of a YouTube video given its URL or ID.
Returns plain text (no timestamps) or raw with timestamps.
"""
# Extract video ID from URL or use as-is if already an ID
if "youtube.com" in video_url or "youtu.be" in video_url:
match = re.search(r"(?:v=|youtu.be/)([\w-]{11})", video_url)
if not match:
return "Invalid YouTube URL or ID."
video_id = match.group(1)
else:
video_id = video_url.strip()
try:
# Fetch transcript using the API
transcript = YouTubeTranscriptApi.get_transcript(video_id)
if raw:
# Return transcript with timestamps
return "\n".join(f"{int(e['start'])}s: {e['text']}" for e in transcript)
# Return plain transcript text
return " ".join(e['text'] for e in transcript)
except TranscriptsDisabled:
return "Transcripts are disabled for this video."
except NoTranscriptFound:
return "No transcript found for this video."
except VideoUnavailable:
return "This video is unavailable."
except Exception as e:
return f"An error occurred while fetching the transcript: {e}"
# List of available tools for the agent (currently only YouTube transcript)
tools = [youtube_transcript]
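
# Quick sanity check (sketch, not exercised by the app): LangChain tools can be
# invoked directly with a dict of arguments; the video URL below is purely
# illustrative.
#   youtube_transcript.invoke({"video_url": "https://youtu.be/dQw4w9WgXcQ", "raw": False})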


def create_web_search_graph() -> StateGraph:
    """
    Create the web search agent graph.

    Returns:
        StateGraph: The compiled web search agent workflow.
    """
    # OpenAI's built-in web search tool (Responses API)
    web_search_preview = [{"type": "web_search_preview"}]
    # Bind the web search tool to the OpenAI LLM
    llm_with_tools = openai_llm.bind_tools(web_search_preview)

    def agent_node(state: AgentState) -> dict:
        """
        Node function for handling web search queries.

        Args:
            state (AgentState): The current agent state.

        Returns:
            dict: Updated state with the LLM response.
        """
        current_date = datetime.now().strftime("%B %d, %Y")
        # Format the system prompt with the current date
        system_message = SystemMessage(content=WEB_SEARCH_PROMPT.format(current_date=current_date))
        response = llm_with_tools.invoke([system_message] + state.get("messages"))
        return {"messages": state.get("messages") + [response]}

    # Build the workflow graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_edge(START, "agent")
    workflow.add_edge("agent", END)
    return workflow.compile(name="web_search_agent")


def create_youtube_viewer_graph() -> StateGraph:
    """
    Create the YouTube viewer agent graph.

    Returns:
        StateGraph: The compiled YouTube viewer agent workflow.
    """

    def agent_node(state: AgentState) -> dict:
        """
        Node function for handling YouTube-related queries.

        Args:
            state (AgentState): The current agent state.

        Returns:
            dict: Updated state with the LLM response.
        """
        current_date = datetime.now().strftime("%B %d, %Y")
        # Format the system prompt with the current date
        system_message = SystemMessage(content=YOUTUBE_PROMPT.format(current_date=current_date))
        # Bind the YouTube transcript tool to the Gemini LLM
        llm_with_tools = google_llm.bind_tools(tools)
        response = llm_with_tools.invoke([system_message] + state.get("messages"))
        return {"messages": state.get("messages") + [response]}

    # Build the workflow graph with a tool node and conditional routing
    workflow = StateGraph(AgentState)
    workflow.add_node("llm", agent_node)
    workflow.add_node("tools", ToolNode(tools))
    workflow.set_entry_point("llm")
    workflow.add_conditional_edges(
        "llm",
        tools_condition,
        {
            "tools": "tools",  # If a tool call is requested, go to the tools node
            "__end__": END,  # Otherwise, end the workflow
        },
    )
    workflow.add_edge("tools", "llm")  # After the tool runs, return to the LLM node
    return workflow.compile(name="youtube_viewer_agent")


def create_multimodal_agent_graph() -> StateGraph:
    """
    Create the multimodal agent graph using Gemini for best multimodal support.

    Returns:
        StateGraph: The compiled multimodal agent workflow.
    """

    def agent_node(state: AgentState) -> dict:
        """
        Node function for handling multimodal queries.

        Args:
            state (AgentState): The current agent state.

        Returns:
            dict: Updated state with the LLM response.
        """
        current_date = datetime.now().strftime("%B %d, %Y")
        # Compose the system message with the multimodal prompt and current date
        system_message = SystemMessage(content=MULTIMODAL_PROMPT + f" Today's date: {current_date}.")
        messages = [system_message] + state.get("messages")
        # Invoke the Gemini LLM for multimodal reasoning
        response = google_llm.invoke(messages)
        return {"messages": state.get("messages") + [response]}

    # Build the workflow graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_edge(START, "agent")
    workflow.add_edge("agent", END)
    return workflow.compile(name="multimodal_agent")


# Instantiate the agent graphs
multimodal_agent = create_multimodal_agent_graph()
web_search_agent = create_web_search_graph()
youtube_agent = create_youtube_viewer_graph()

# Create the supervisor workflow to route queries to the appropriate sub-agent
supervisor_workflow = create_supervisor(
    [web_search_agent, youtube_agent, multimodal_agent],
    model=openai_llm,
    prompt=(
        "You are a supervisor. For each question, call one of your sub-agents and "
        "return their answer directly to the user. Do not modify, summarize, or "
        "rephrase the answer."
    ),
)

# Compile the supervisor agent for use in the application
supervisor_agent = supervisor_workflow.compile(name="supervisor_agent")
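

# Example usage (a minimal sketch; the question text and printing below are
# illustrative, not part of the application code). The compiled supervisor graph
# takes the standard LangGraph message-state input and returns the full message
# history, whose last entry is the final answer routed back by the supervisor.
if __name__ == "__main__":
    result = supervisor_agent.invoke(
        {"messages": [{"role": "user", "content": "What were the top AI headlines this week?"}]}
    )
    # Print the content of the last message (the final answer)
    print(result["messages"][-1].content)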