import base64 | |
import os | |
import time | |
from pathlib import Path | |
from typing import Dict, Optional | |
import requests | |
import json | |
import gradio as gr | |
import uuid | |
from langchain_anthropic import ChatAnthropic | |
from langchain_mistralai import ChatMistralAI | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_ollama import ChatOllama | |
from langchain_openai import AzureChatOpenAI, ChatOpenAI | |
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama | |
# Human-readable display names keyed by internal provider id. Used by
# MissingAPIKeyError to build user-facing error messages; providers missing
# here fall back to provider.upper().
PROVIDER_DISPLAY_NAMES = {
    "openai": "OpenAI",
    "azure_openai": "Azure OpenAI",
    "anthropic": "Anthropic",
    "deepseek": "DeepSeek",
    "google": "Google",
    "alibaba": "Alibaba",
    "moonshot": "MoonShot",
    "unbound": "Unbound AI"
}
def get_llm_model(provider: str, **kwargs):
    """
    Build and return a LangChain chat-model client for the given provider.

    :param provider: provider id (e.g. "openai", "anthropic", "deepseek",
        "google", "ollama", "azure_openai", "alibaba", "moonshot",
        "unbound", "mistral", "siliconflow").
    :param kwargs: optional overrides — model_name, temperature, base_url,
        api_key, api_version (azure), num_ctx / num_predict (ollama).
    :return: a configured LangChain chat model instance.
    :raises MissingAPIKeyError: if the provider requires an API key and none
        is found in kwargs or the <PROVIDER>_API_KEY environment variable.
    :raises ValueError: if the provider is not supported.
    """
    # Every provider except ollama requires an API key, either passed in or
    # taken from the <PROVIDER>_API_KEY environment variable.
    if provider not in ["ollama"]:
        env_var = f"{provider.upper()}_API_KEY"
        api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
        if not api_key:
            raise MissingAPIKeyError(provider, env_var)
        kwargs["api_key"] = api_key

    temperature = kwargs.get("temperature", 0.0)

    def _base_url(endpoint_env: str, default: str = "") -> str:
        # An explicit base_url kwarg wins; otherwise fall back to the env var.
        return kwargs.get("base_url", "") or os.getenv(endpoint_env, default)

    if provider == "anthropic":
        return ChatAnthropic(
            model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
            temperature=temperature,
            base_url=kwargs.get("base_url", "") or "https://api.anthropic.com",
            api_key=api_key,
        )
    elif provider == 'mistral':
        # api_key was already resolved above (kwargs or MISTRAL_API_KEY); the
        # previous duplicate env lookup in this branch was redundant.
        return ChatMistralAI(
            model=kwargs.get("model_name", "mistral-large-latest"),
            temperature=temperature,
            base_url=_base_url("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1"),
            api_key=api_key,
        )
    elif provider == "openai":
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=temperature,
            base_url=_base_url("OPENAI_ENDPOINT", "https://api.openai.com/v1"),
            api_key=api_key,
        )
    elif provider == "deepseek":
        base_url = _base_url("DEEPSEEK_ENDPOINT")
        # deepseek-reasoner needs the R1-aware client wrapper.
        if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
            return DeepSeekR1ChatOpenAI(
                model=kwargs.get("model_name", "deepseek-reasoner"),
                temperature=temperature,
                base_url=base_url,
                api_key=api_key,
            )
        return ChatOpenAI(
            model=kwargs.get("model_name", "deepseek-chat"),
            temperature=temperature,
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "google":
        return ChatGoogleGenerativeAI(
            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
            temperature=temperature,
            api_key=api_key,
        )
    elif provider == "ollama":
        base_url = _base_url("OLLAMA_ENDPOINT", "http://localhost:11434")
        # Local deepseek-r1 models also need the R1-aware wrapper.
        if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
            return DeepSeekR1ChatOllama(
                model=kwargs.get("model_name", "deepseek-r1:14b"),
                temperature=temperature,
                num_ctx=kwargs.get("num_ctx", 32000),
                base_url=base_url,
            )
        return ChatOllama(
            model=kwargs.get("model_name", "qwen2.5:7b"),
            temperature=temperature,
            num_ctx=kwargs.get("num_ctx", 32000),
            num_predict=kwargs.get("num_predict", 1024),
            base_url=base_url,
        )
    elif provider == "azure_openai":
        api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
        return AzureChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=temperature,
            api_version=api_version,
            azure_endpoint=_base_url("AZURE_OPENAI_ENDPOINT"),
            api_key=api_key,
        )
    elif provider == "alibaba":
        return ChatOpenAI(
            model=kwargs.get("model_name", "qwen-plus"),
            temperature=temperature,
            base_url=_base_url("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
            api_key=api_key,
        )
    elif provider == "moonshot":
        # NOTE(review): moonshot ignores base_url/api_key kwargs and reads
        # the environment directly — kept as-is to preserve behavior.
        return ChatOpenAI(
            model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
            temperature=temperature,
            base_url=os.getenv("MOONSHOT_ENDPOINT"),
            api_key=os.getenv("MOONSHOT_API_KEY"),
        )
    elif provider == "unbound":
        # NOTE(review): unbound also ignores a base_url kwarg — kept as-is.
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o-mini"),
            temperature=temperature,
            base_url=os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
            api_key=api_key,
        )
    elif provider == "siliconflow":
        # api_key is already resolved above from kwargs/SILICONFLOW_API_KEY;
        # the old extra lookup of "SiliconFLOW_API_KEY" here was dead code
        # because kwargs["api_key"] is always set by the time this runs.
        return ChatOpenAI(
            api_key=api_key,
            base_url=_base_url("SiliconFLOW_ENDPOINT"),
            model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
            temperature=temperature,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
# Predefined model names for common providers.
# Keyed by the same provider ids get_llm_model() accepts; consumed by
# update_model_dropdown() to populate the UI dropdown (first entry is the
# default selection). Providers absent here get a free-text dropdown.
model_names = {
    "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"],
    "deepseek": ["deepseek-chat", "deepseek-reasoner"],
    "google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest",
               "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05"],
    "ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b",
               "deepseek-r1:14b", "deepseek-r1:32b"],
    "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
    "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
    "alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"],
    "moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"],
    "unbound": ["gemini-2.0-flash", "gpt-4o-mini", "gpt-4o", "gpt-4.5-preview"],
    "siliconflow": [
        "deepseek-ai/DeepSeek-R1",
        "deepseek-ai/DeepSeek-V3",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "deepseek-ai/DeepSeek-V2.5",
        "deepseek-ai/deepseek-vl2",
        "Qwen/Qwen2.5-72B-Instruct-128K",
        "Qwen/Qwen2.5-72B-Instruct",
        "Qwen/Qwen2.5-32B-Instruct",
        "Qwen/Qwen2.5-14B-Instruct",
        "Qwen/Qwen2.5-7B-Instruct",
        "Qwen/Qwen2.5-Coder-32B-Instruct",
        "Qwen/Qwen2.5-Coder-7B-Instruct",
        "Qwen/Qwen2-7B-Instruct",
        "Qwen/Qwen2-1.5B-Instruct",
        "Qwen/QwQ-32B-Preview",
        "Qwen/Qwen2-VL-72B-Instruct",
        "Qwen/Qwen2.5-VL-32B-Instruct",
        "Qwen/Qwen2.5-VL-72B-Instruct",
        "TeleAI/TeleChat2",
        "THUDM/glm-4-9b-chat",
        "Vendor-A/Qwen/Qwen2.5-72B-Instruct",
        "internlm/internlm2_5-7b-chat",
        "internlm/internlm2_5-20b-chat",
        "Pro/Qwen/Qwen2.5-7B-Instruct",
        "Pro/Qwen/Qwen2-7B-Instruct",
        "Pro/Qwen/Qwen2-1.5B-Instruct",
        "Pro/THUDM/chatglm3-6b",
        "Pro/THUDM/glm-4-9b-chat",
    ],
}
# Callback to update the model name dropdown based on the selected provider
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
    """
    Update the model name dropdown with predefined models for the selected provider.

    :param llm_provider: provider id used to index `model_names`.
    :param api_key: kept for backward compatibility with existing callers;
        not needed to populate the dropdown.
    :param base_url: kept for backward compatibility; likewise unused.
    :return: a gr.Dropdown — preset choices (first one selected) for known
        providers, or an empty free-text dropdown for unknown ones.
    """
    # The previous implementation re-imported gradio locally (it is already
    # imported at module level) and resolved api_key/base_url from the
    # environment without ever using them; both were dead code and removed.
    if llm_provider in model_names:
        return gr.Dropdown(choices=model_names[llm_provider], value=model_names[llm_provider][0], interactive=True)
    return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
class MissingAPIKeyError(Exception):
    """Raised when no API key is available for a requested LLM provider."""

    def __init__(self, provider: str, env_var: str):
        # Prefer the friendly display name; fall back to the upper-cased id.
        display = PROVIDER_DISPLAY_NAMES.get(provider, provider.upper())
        message = (f"💥 {display} API key not found! 🔑 Please set the "
                   f"`{env_var}` environment variable or provide it in the UI.")
        super().__init__(message)
def encode_image(img_path):
    """Return the file at *img_path* as a base64-encoded UTF-8 string.

    Returns None when img_path is falsy (e.g. None or "").
    """
    if not img_path:
        return None
    raw = Path(img_path).read_bytes()
    return base64.b64encode(raw).decode("utf-8")
def get_latest_files(directory: str, file_types: Optional[list] = None) -> Dict[str, Optional[str]]:
    """Get the latest recording and trace files.

    :param directory: root directory to search recursively (created if missing).
    :param file_types: file-name suffixes to look for; defaults to
        ['.webm', '.zip'].
    :return: mapping of suffix -> newest matching file path, or None when no
        (finished) match exists.
    """
    # Resolve the default here instead of using a mutable default argument.
    if file_types is None:
        file_types = ['.webm', '.zip']
    latest_files: Dict[str, Optional[str]] = {ext: None for ext in file_types}
    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)
        return latest_files
    for file_type in file_types:
        try:
            matches = list(Path(directory).rglob(f"*{file_type}"))
            if matches:
                latest = max(matches, key=lambda p: p.stat().st_mtime)
                # Only return files that are complete: anything modified within
                # the last second may still be open and being written.
                if time.time() - latest.stat().st_mtime > 1.0:
                    latest_files[file_type] = str(latest)
        except Exception as e:
            # Best-effort: a failure on one suffix must not abort the others.
            print(f"Error getting latest {file_type} file: {e}")
    return latest_files
async def capture_screenshot(browser_context):
    """Capture a JPEG screenshot of the active page, base64-encoded.

    Returns None when no reusable Playwright context/page is available or the
    screenshot call itself fails.
    """
    # Extract the underlying Playwright browser and reuse its first context,
    # if any exists; without one there is nothing to screenshot.
    playwright_browser = browser_context.browser.playwright_browser  # Ensure this is correct.
    if not (playwright_browser and playwright_browser.contexts):
        return None
    playwright_context = playwright_browser.contexts[0]

    pages = playwright_context.pages if playwright_context else None
    if not pages:
        return None

    # Prefer the last page that has navigated somewhere real; fall back to
    # the first page when everything is still about:blank.
    active_page = pages[0]
    for candidate in pages:
        if candidate.url != "about:blank":
            active_page = candidate

    try:
        shot = await active_page.screenshot(
            type='jpeg',
            quality=75,
            scale="css"
        )
    except Exception:
        # Screenshots can fail mid-navigation; treat that as "no frame".
        return None
    return base64.b64encode(shot).decode('utf-8')
class ConfigManager:
    """Tracks registered gradio components so the whole UI configuration can
    be saved to, and restored from, a JSON file."""

    def __init__(self):
        # name -> component, plus registration order for stable save/load.
        self.components = {}
        self.component_order = []

    def register_component(self, name: str, component):
        """Register a gradio component for config management."""
        self.components[name] = component
        if name not in self.component_order:
            self.component_order.append(name)
        return component

    def save_current_config(self):
        """Save the current configuration of all registered components."""
        current_config = {
            name: getattr(self.components[name], "value", None)
            for name in self.component_order
        }
        return save_config_to_file(current_config)

    def update_ui_from_config(self, config_file):
        """Update UI components from a loaded configuration file."""
        if config_file is None:
            return [gr.update() for _ in self.component_order] + ["No file selected."]
        loaded_config = load_config_from_file(config_file.name)
        if not isinstance(loaded_config, dict):
            return [gr.update() for _ in self.component_order] + ["Error: Invalid configuration file."]
        # One update per registered component, in registration order; names
        # absent from the file keep their current value.
        updates = [
            gr.update(value=loaded_config[name]) if name in loaded_config else gr.update()
            for name in self.component_order
        ]
        updates.append("Configuration loaded successfully.")
        return updates

    def get_all_components(self):
        """Return all registered components in the order they were registered."""
        return [self.components[name] for name in self.component_order]
def load_config_from_file(config_file):
    """Load settings from a config file (JSON format).

    Returns the parsed object on success, or an error-message string on any
    failure (callers distinguish the two with an isinstance(dict) check).
    """
    try:
        with open(config_file, 'r') as fh:
            return json.load(fh)
    except Exception as e:
        return f"Error loading configuration: {str(e)}"
def save_config_to_file(settings, save_dir="./tmp/webui_settings"):
    """Save the current settings to a UUID.json file with a UUID name."""
    os.makedirs(save_dir, exist_ok=True)
    # Random file name so successive saves never overwrite each other.
    filename = f"{uuid.uuid4()}.json"
    config_file = os.path.join(save_dir, filename)
    with open(config_file, 'w') as fh:
        json.dump(settings, fh, indent=2)
    return f"Configuration saved to {config_file}"