# Spaces: Running
import os | |
import json | |
import yaml | |
import litellm | |
import logging | |
from dotenv import load_dotenv | |
from huggingface_hub import login | |
from selenium import webdriver | |
from selenium.webdriver.common.by import By | |
from selenium.webdriver.common.keys import Keys | |
from io import BytesIO | |
from PIL import Image | |
from datetime import datetime | |
import tempfile | |
import helium | |
import gradio as gr | |
from smolagents import CodeAgent, LiteLLMModel | |
from smolagents.agents import ActionStep | |
from tools.search_item_ctrl_f import SearchItemCtrlFTool | |
from tools.go_back import GoBackTool | |
from tools.close_popups import ClosePopupsTool | |
from tools.final_answer import FinalAnswerTool | |
from GRADIO_UI import GradioUI | |
# Set up logging: DEBUG level for this module and verbose output from litellm.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
litellm.set_verbose = True
logger.debug("Configuring litellm for gemini/gemini-2.0-flash")

# Load environment variables (a local .env file is read first, if present).
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
default_gemini_api_key = os.getenv("GOOGLE_API_KEY")

# Warn about Anthropic key
if os.getenv("ANTHROPIC_API_KEY"):
    logger.warning("ANTHROPIC_API_KEY found in environment. This may cause conflicts.")

# Only log in when a token is actually configured: calling login() without a
# token falls back to an interactive prompt, which cannot be answered when
# the app runs unattended.
if hf_token:
    login(hf_token, add_to_git_credential=False)
else:
    logger.warning("HF_TOKEN not set; skipping Hugging Face Hub login.")
# Start the shared Chrome session and hand it to helium.
# The flags are applied in the order listed below.
CHROME_FLAGS = (
    "--force-device-scale-factor=1",
    "--window-size=1000,1350",
    "--disable-pdf-viewer",
    "--no-sandbox",
    "--disable-dev-shm-usage",
    "--window-position=0,0",
    "--headless=new",
)

try:
    chrome_options = webdriver.ChromeOptions()
    for chrome_flag in CHROME_FLAGS:
        chrome_options.add_argument(chrome_flag)

    driver = webdriver.Chrome(options=chrome_options)
    driver.implicitly_wait(5)
    # Register the driver with helium so helium calls act on this session.
    helium.set_driver(driver)
    logger.info("Chrome driver initialized successfully.")
except Exception as exc:
    # Nothing below can work without a browser: log the cause and abort.
    logger.error(f"Failed to initialize Chrome driver: {str(exc)}")
    raise
# Screenshot callback
def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> "str | None":
    """Save a screenshot of the current browser view and record it on the step.

    Registered as an agent step callback. After a one-second pause, it clears
    ``observations_images`` on every earlier ``ActionStep`` in the agent's
    memory, writes the current view to a PNG under the system temp directory,
    and appends the current URL and the file path to the step's
    ``observations`` text.

    Args:
        memory_step: The step being finalized; its ``step_number`` is used in
            the file name and its ``observations`` text is extended.
        agent: The running agent whose ``memory.steps`` are scanned.

    Returns:
        The path of the saved screenshot, or ``None`` when helium has no
        active driver.
    """
    from time import sleep

    sleep(1.0)  # brief pause before capturing the page
    driver = helium.get_driver()
    current_step = memory_step.step_number
    if driver is None:
        return None

    # Drop images held by earlier steps.
    for previous_memory_step in agent.memory.steps:
        if (
            isinstance(previous_memory_step, ActionStep)
            and previous_memory_step.step_number < current_step
        ):
            previous_memory_step.observations_images = None

    # Capture only the viewport
    png_bytes = driver.get_screenshot_as_png()
    image = Image.open(BytesIO(png_bytes))

    screenshot_dir = os.path.join(tempfile.gettempdir(), "web_agent_screenshots")
    os.makedirs(screenshot_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    screenshot_filename = f"screenshot_step_{current_step}_{timestamp}.png"
    screenshot_path = os.path.join(screenshot_dir, screenshot_filename)
    image.save(screenshot_path)
    logger.info(f"Saved screenshot to: {screenshot_path}")

    # Append (never overwrite) so observations already on the step survive.
    url_info = f"Current url: {driver.current_url}\nScreenshot saved at: {screenshot_path}"
    memory_step.observations = (
        url_info if memory_step.observations is None else memory_step.observations + "\n" + url_info
    )
    return screenshot_path
# Load prompt templates from prompts.yaml (resolved against the current
# working directory). A missing or empty file both yield an empty mapping.
try:
    # Explicit encoding keeps decoding independent of the platform default.
    with open("prompts.yaml", "r", encoding="utf-8") as stream:
        # safe_load returns None for an empty document; normalize it to {}.
        prompt_templates = yaml.safe_load(stream) or {}
except FileNotFoundError:
    prompt_templates = {}
# Build one instance of each browser tool, all bound to the shared driver.
tools = [
    tool_cls(driver=driver)
    for tool_cls in (
        SearchItemCtrlFTool,
        GoBackTool,
        ClosePopupsTool,
        FinalAnswerTool,
    )
]

# Log each tool at debug level; a tool whose description cannot be rendered
# is reported as an error without stopping startup.
for idx, tool in enumerate(tools):
    try:
        tool_name = getattr(tool, "name", f"Unknown_{idx}")
        logger.debug(
            f"Registering tool {idx}: {tool.__class__.__name__}, name: {tool_name}, instance: {tool}"
        )
    except Exception as e:
        logger.error(f"Failed to register tool {idx}: {str(e)}")
# Initialize model with API key handling
def _mask_api_key(key):
    """Return a log-safe form of ``key`` that reveals at most its last four characters."""
    if not key:
        return "None"
    if len(key) <= 4:
        # Too short to reveal any part of it safely.
        return "****"
    return "****" + key[-4:]


def initialize_model(gemini_api_key=None):
    """Create the Gemini ``LiteLLMModel`` using the best available API key.

    Args:
        gemini_api_key: Optional user-supplied key. When it is missing or
            blank, the module-level ``default_gemini_api_key`` is used.

    Returns:
        A ``LiteLLMModel`` configured for ``gemini/gemini-2.0-flash``.

    Raises:
        gr.Error: If no key is available or the model cannot be constructed.
    """
    user_key = gemini_api_key.strip() if gemini_api_key else ""

    # Never log the key itself: only a masked form showing the last 4 chars.
    if user_key:
        logger.debug(f"Using user-provided API key: {_mask_api_key(user_key)}")
    else:
        logger.debug(f"Using default API key: {_mask_api_key(default_gemini_api_key)}")

    try:
        api_key = user_key or default_gemini_api_key
        if not api_key:
            raise ValueError("No valid API key provided and GOOGLE_API_KEY not set in environment")
        return LiteLLMModel("gemini/gemini-2.0-flash", api_key=api_key)
    except Exception as e:
        logger.error(f"Failed to initialize LiteLLMModel: {str(e)}")
        # Re-raise as a Gradio error, keeping the original cause chained.
        raise gr.Error(f"API Key Error: {str(e)}", duration=5) from e
# Agent factory
def initialize_agent(gemini_api_key=None):
    """Build a ``CodeAgent`` wired to the browser tools and screenshot callback.

    Args:
        gemini_api_key: Optional API key forwarded to ``initialize_model``.

    Returns:
        The configured agent, after ``from helium import *`` has been run in
        its Python executor.
    """
    agent_settings = dict(
        model=initialize_model(gemini_api_key),
        tools=tools,
        max_steps=20,
        verbosity_level=2,
        prompt_templates=prompt_templates,
        step_callbacks=[save_screenshot],
        additional_authorized_imports=["helium"],
    )
    web_agent = CodeAgent(**agent_settings)
    # Presumably so agent-written code can call helium functions unqualified.
    web_agent.python_executor("from helium import *")
    return web_agent
# Launch Gradio UI with API key support.
# Cleanup lives in `finally` so the Chrome process is released on every exit
# path, not only when a KeyboardInterrupt reaches this frame.
# NOTE(review): assumes launch() blocks until the server stops - confirm
# against GRADIO_UI.GradioUI.
try:
    GradioUI(initialize_agent).launch()
except KeyboardInterrupt:
    # Ctrl-C is the expected way to stop the app; exit without a traceback.
    pass
finally:
    driver.quit()
    logger.info("Chrome driver closed on exit.")