File size: 2,642 Bytes
ac6a4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Utility & helper functions."""

import os
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
import asyncio
from datetime import UTC, datetime
from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS


# Load environment variables from a .env file. This call sits at module level,
# so it runs once when this module is imported, before any helper below
# (e.g. the GOOGLE_API_KEY check in load_chat_model) reads os.environ.
load_dotenv()


def get_message_text(msg: BaseMessage) -> str:
    """Extract the plain text carried by a message.

    Args:
        msg: The message whose ``content`` attribute is read.

    Returns:
        The content itself when it is a string; the value stored under
        ``"text"`` (``""`` when the key is absent) when it is a dict;
        otherwise the concatenation of the text of every part, with leading
        and trailing whitespace stripped.
    """
    body = msg.content
    if isinstance(body, str):
        return body
    if isinstance(body, dict):
        return body.get("text", "")

    # Anything else is iterated as a sequence of parts: plain strings are
    # kept verbatim, other parts contribute their "text" entry (or nothing
    # when that entry is missing or falsy).
    pieces = []
    for part in body:
        if isinstance(part, str):
            pieces.append(part)
        else:
            pieces.append(part.get("text") or "")
    return "".join(pieces).strip()


def format_system_prompt(prompt_template: str) -> str:
    """Fill a system prompt template with the current time and agent details.

    Args:
        prompt_template: A ``str.format``-style template. The fields supplied
            to it are ``system_time``, ``workers``, ``members``,
            ``worker_options``, ``example_worker_1``, ``example_worker_2``,
            ``correct_verdict`` and ``retry_verdict``.

    Returns:
        The template with those fields substituted.
    """
    # Example worker names for the template, with fallbacks when WORKERS
    # has fewer than two entries.
    if WORKERS:
        first_worker = WORKERS[0]
    else:
        first_worker = "researcher"
    if len(WORKERS) > 1:
        second_worker = WORKERS[1]
    else:
        second_worker = "coder"

    # Verdict labels for the template, with fallbacks when VERDICTS has
    # fewer than two entries.
    if VERDICTS:
        verdict_ok = VERDICTS[0]
    else:
        verdict_ok = "CORRECT"
    if len(VERDICTS) > 1:
        verdict_retry = VERDICTS[1]
    else:
        verdict_retry = "RETRY"

    substitutions = {
        # Timezone-aware UTC timestamp in ISO 8601 form.
        "system_time": datetime.now(tz=UTC).isoformat(),
        "workers": ", ".join(WORKERS),
        "members": ", ".join(MEMBERS),
        # Same worker names, each wrapped in double quotes.
        "worker_options": ", ".join(f'"{name}"' for name in WORKERS),
        "example_worker_1": first_worker,
        "example_worker_2": second_worker,
        "correct_verdict": verdict_ok,
        "retry_verdict": verdict_retry,
    }
    return prompt_template.format(**substitutions)


def load_chat_model(fully_specified_name: str) -> BaseChatModel:
    """Load a chat model from a fully specified name.

    Args:
        fully_specified_name (str): String in the format 'provider/model'.
            Only the first '/' separates the two parts, so the model part
            may itself contain slashes.

    Returns:
        A ``ChatGoogleGenerativeAI`` instance when the provider is
        ``google_genai``; otherwise whatever ``init_chat_model`` returns for
        the given model and provider.

    Raises:
        ValueError: If the name contains no '/' separator, or if the provider
            is ``google_genai`` and ``GOOGLE_API_KEY`` is unset or empty.
    """
    # partition() never raises, so a missing separator can be reported with a
    # descriptive message instead of an opaque tuple-unpacking error.
    provider, separator, model = fully_specified_name.partition("/")
    if not separator:
        raise ValueError(
            "Model name must be in the format 'provider/model', "
            f"got {fully_specified_name!r}"
        )

    # Google Genai models are constructed directly rather than through
    # init_chat_model, after verifying that an API key is available.
    if provider == "google_genai":
        # Imported lazily so the package is only needed when this provider
        # is actually requested.
        from langchain_google_genai import ChatGoogleGenerativeAI

        # Make sure we have the API key
        if not os.environ.get("GOOGLE_API_KEY"):
            raise ValueError("GOOGLE_API_KEY environment variable is required for google_genai models")

        return ChatGoogleGenerativeAI(model=model)

    return init_chat_model(model, model_provider=provider)