# Sanjeev23oct's picture
# Upload folder using huggingface_hub
# f1d5e1c verified
import base64
import os
import time
from pathlib import Path
from typing import Dict, Optional
import requests
import json
import gradio as gr
import uuid
from langchain_anthropic import ChatAnthropic
from langchain_mistralai import ChatMistralAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
# Human-readable provider labels, looked up by provider key when building
# user-facing messages (e.g. the missing-API-key error). Providers without an
# entry fall back to the upper-cased key at the lookup site.
PROVIDER_DISPLAY_NAMES = {
    "openai": "OpenAI",
    "azure_openai": "Azure OpenAI",
    "anthropic": "Anthropic",
    "deepseek": "DeepSeek",
    "google": "Google",
    "alibaba": "Alibaba",
    "moonshot": "MoonShot",
    "unbound": "Unbound AI",
    # These two providers also require an API key, so they need a label too.
    "mistral": "Mistral",
    "siliconflow": "SiliconFlow",
}
def get_llm_model(provider: str, **kwargs):
    """
    Build a chat-model client for the requested LLM provider.

    :param provider: provider key, one of ``anthropic``, ``mistral``, ``openai``,
        ``deepseek``, ``google``, ``ollama``, ``azure_openai``, ``alibaba``,
        ``moonshot``, ``unbound`` or ``siliconflow``.
    :param kwargs: optional settings. Keys read here: ``api_key``, ``base_url``,
        ``model_name``, ``temperature``, ``api_version`` (azure_openai only),
        ``num_ctx`` and ``num_predict`` (ollama only). A missing or empty
        ``api_key`` / ``base_url`` falls back to the provider's environment
        variable, then to the built-in default where one exists.
    :return: the chat-model instance constructed for the provider.
    :raises MissingAPIKeyError: if the provider is not ``ollama`` and no API key
        is given via ``api_key`` or the environment.
    :raises ValueError: if ``provider`` is not a supported key.
    """
    api_key = ""
    if provider not in ["ollama"]:
        env_var = f"{provider.upper()}_API_KEY"
        api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
        if not api_key and provider == "siliconflow":
            # Legacy mixed-case variable name; it must be consulted before the
            # missing-key check, otherwise the fallback can never be reached.
            api_key = os.getenv("SiliconFLOW_API_KEY", "")
        if not api_key:
            raise MissingAPIKeyError(provider, env_var)

    temperature = kwargs.get("temperature", 0.0)

    def endpoint(env_name, default=None):
        # An explicit, non-empty base_url wins; otherwise use the environment
        # variable (or the given default when the variable is unset).
        return kwargs.get("base_url", "") or os.getenv(env_name, default)

    if provider == "anthropic":
        return ChatAnthropic(
            model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
            temperature=temperature,
            base_url=kwargs.get("base_url", "") or "https://api.anthropic.com",
            api_key=api_key,
        )
    elif provider == "mistral":
        return ChatMistralAI(
            model=kwargs.get("model_name", "mistral-large-latest"),
            temperature=temperature,
            base_url=endpoint("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1"),
            api_key=api_key,
        )
    elif provider == "openai":
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=temperature,
            base_url=endpoint("OPENAI_ENDPOINT", "https://api.openai.com/v1"),
            api_key=api_key,
        )
    elif provider == "deepseek":
        base_url = endpoint("DEEPSEEK_ENDPOINT", "")
        # The reasoner model goes through the project's DeepSeek-R1 wrapper.
        if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
            return DeepSeekR1ChatOpenAI(
                model=kwargs.get("model_name", "deepseek-reasoner"),
                temperature=temperature,
                base_url=base_url,
                api_key=api_key,
            )
        return ChatOpenAI(
            model=kwargs.get("model_name", "deepseek-chat"),
            temperature=temperature,
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "google":
        return ChatGoogleGenerativeAI(
            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
            temperature=temperature,
            api_key=api_key,
        )
    elif provider == "ollama":
        base_url = endpoint("OLLAMA_ENDPOINT", "http://localhost:11434")
        # Any model whose name contains "deepseek-r1" uses the R1 wrapper.
        if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
            return DeepSeekR1ChatOllama(
                model=kwargs.get("model_name", "deepseek-r1:14b"),
                temperature=temperature,
                num_ctx=kwargs.get("num_ctx", 32000),
                base_url=base_url,
            )
        return ChatOllama(
            model=kwargs.get("model_name", "qwen2.5:7b"),
            temperature=temperature,
            num_ctx=kwargs.get("num_ctx", 32000),
            num_predict=kwargs.get("num_predict", 1024),
            base_url=base_url,
        )
    elif provider == "azure_openai":
        api_version = kwargs.get("api_version", "") or os.getenv(
            "AZURE_OPENAI_API_VERSION", "2025-01-01-preview"
        )
        return AzureChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=temperature,
            api_version=api_version,
            azure_endpoint=endpoint("AZURE_OPENAI_ENDPOINT", ""),
            api_key=api_key,
        )
    elif provider == "alibaba":
        return ChatOpenAI(
            model=kwargs.get("model_name", "qwen-plus"),
            temperature=temperature,
            base_url=endpoint(
                "ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1"
            ),
            api_key=api_key,
        )
    elif provider == "moonshot":
        # Use the key validated above (explicit api_key or MOONSHOT_API_KEY)
        # and honor an explicit base_url, rather than reading only the
        # environment and ignoring what the caller supplied.
        return ChatOpenAI(
            model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
            temperature=temperature,
            base_url=endpoint("MOONSHOT_ENDPOINT"),
            api_key=api_key,
        )
    elif provider == "unbound":
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o-mini"),
            temperature=temperature,
            base_url=endpoint("UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
            api_key=api_key,
        )
    elif provider == "siliconflow":
        return ChatOpenAI(
            api_key=api_key,
            base_url=endpoint("SiliconFLOW_ENDPOINT", ""),
            model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
            temperature=temperature,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
# Model names offered for each provider key. The first entry of each list is
# the one pre-selected by the model dropdown callback.
model_names = {
    "anthropic": [
        "claude-3-5-sonnet-20241022",
        "claude-3-5-sonnet-20240620",
        "claude-3-opus-20240229",
    ],
    "openai": [
        "gpt-4o",
        "gpt-4",
        "gpt-3.5-turbo",
        "o3-mini",
    ],
    "deepseek": [
        "deepseek-chat",
        "deepseek-reasoner",
    ],
    "google": [
        "gemini-2.0-flash",
        "gemini-2.0-flash-thinking-exp",
        "gemini-1.5-flash-latest",
        "gemini-1.5-flash-8b-latest",
        "gemini-2.0-flash-thinking-exp-01-21",
        "gemini-2.0-pro-exp-02-05",
    ],
    "ollama": [
        "qwen2.5:7b",
        "qwen2.5:14b",
        "qwen2.5:32b",
        "qwen2.5-coder:14b",
        "qwen2.5-coder:32b",
        "llama2:7b",
        "deepseek-r1:14b",
        "deepseek-r1:32b",
    ],
    "azure_openai": [
        "gpt-4o",
        "gpt-4",
        "gpt-3.5-turbo",
    ],
    "mistral": [
        "pixtral-large-latest",
        "mistral-large-latest",
        "mistral-small-latest",
        "ministral-8b-latest",
    ],
    "alibaba": [
        "qwen-plus",
        "qwen-max",
        "qwen-turbo",
        "qwen-long",
    ],
    "moonshot": [
        "moonshot-v1-32k-vision-preview",
        "moonshot-v1-8k-vision-preview",
    ],
    "unbound": [
        "gemini-2.0-flash",
        "gpt-4o-mini",
        "gpt-4o",
        "gpt-4.5-preview",
    ],
    "siliconflow": [
        "deepseek-ai/DeepSeek-R1",
        "deepseek-ai/DeepSeek-V3",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "deepseek-ai/DeepSeek-V2.5",
        "deepseek-ai/deepseek-vl2",
        "Qwen/Qwen2.5-72B-Instruct-128K",
        "Qwen/Qwen2.5-72B-Instruct",
        "Qwen/Qwen2.5-32B-Instruct",
        "Qwen/Qwen2.5-14B-Instruct",
        "Qwen/Qwen2.5-7B-Instruct",
        "Qwen/Qwen2.5-Coder-32B-Instruct",
        "Qwen/Qwen2.5-Coder-7B-Instruct",
        "Qwen/Qwen2-7B-Instruct",
        "Qwen/Qwen2-1.5B-Instruct",
        "Qwen/QwQ-32B-Preview",
        "Qwen/Qwen2-VL-72B-Instruct",
        "Qwen/Qwen2.5-VL-32B-Instruct",
        "Qwen/Qwen2.5-VL-72B-Instruct",
        "TeleAI/TeleChat2",
        "THUDM/glm-4-9b-chat",
        "Vendor-A/Qwen/Qwen2.5-72B-Instruct",
        "internlm/internlm2_5-7b-chat",
        "internlm/internlm2_5-20b-chat",
        "Pro/Qwen/Qwen2.5-7B-Instruct",
        "Pro/Qwen/Qwen2-7B-Instruct",
        "Pro/Qwen/Qwen2-1.5B-Instruct",
        "Pro/THUDM/chatglm3-6b",
        "Pro/THUDM/glm-4-9b-chat",
    ],
}
# Callback to update the model name dropdown based on the selected provider
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
    """
    Build the model-name dropdown for the selected provider.

    :param llm_provider: provider key used to look up ``model_names``.
    :param api_key: accepted for callback-signature compatibility; not used.
    :param base_url: accepted for callback-signature compatibility; not used.
    :return: a ``gr.Dropdown`` listing the provider's predefined models with the
        first one selected, or an empty dropdown that accepts custom values
        when the provider has no predefined models.
    """
    import gradio as gr

    # The previous environment lookups for api_key/base_url were never used and
    # raised AttributeError for a None provider, so they were dropped.
    choices = model_names.get(llm_provider)
    if choices:
        return gr.Dropdown(choices=choices, value=choices[0], interactive=True)
    # Unknown provider, or a provider with an empty list (indexing [0] would
    # fail): let the user type a model name.
    return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
class MissingAPIKeyError(Exception):
    """Raised when no API key is available for an LLM provider."""

    def __init__(self, provider: str, env_var: str):
        """
        :param provider: provider key; shown via its display name when one is
            registered, otherwise upper-cased.
        :param env_var: name of the environment variable the user should set.
        """
        label = PROVIDER_DISPLAY_NAMES.get(provider, provider.upper())
        message = (
            f"💥 {label} API key not found! 🔑 Please set the "
            f"`{env_var}` environment variable or provide it in the UI."
        )
        super().__init__(message)
def encode_image(img_path):
    """
    Read a file and return its contents base64-encoded as a UTF-8 string.

    :param img_path: path of the file to read; a falsy value yields ``None``.
    :return: base64 text of the file's bytes, or ``None`` when no path is given.
    """
    if img_path:
        with open(img_path, "rb") as image_file:
            raw_bytes = image_file.read()
        return base64.b64encode(raw_bytes).decode("utf-8")
    return None
def get_latest_files(directory: str, file_types: list = ['.webm', '.zip']) -> Dict[str, Optional[str]]:
    """
    Find the most recently modified file for each extension under a directory.

    The directory is searched recursively. If it does not exist it is created
    and every extension maps to ``None``. A file modified within the last
    second is skipped, as it may still be being written.

    :param directory: root directory to search.
    :param file_types: file suffixes to look for (e.g. ``'.webm'``).
    :return: mapping of each suffix to the newest matching path, or ``None``.
    """
    newest: Dict[str, Optional[str]] = dict.fromkeys(file_types)

    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)
        return newest

    for suffix in file_types:
        try:
            candidates = list(Path(directory).rglob(f"*{suffix}"))
            if not candidates:
                continue
            winner = max(candidates, key=lambda path: path.stat().st_mtime)
            # Only return files that are complete (not being written)
            age_seconds = time.time() - winner.stat().st_mtime
            if age_seconds > 1.0:
                newest[suffix] = str(winner)
        except Exception as e:
            print(f"Error getting latest {suffix} file: {e}")

    return newest
async def capture_screenshot(browser_context):
    """
    Take a JPEG screenshot of the active page and return it base64-encoded.

    The first context of the underlying Playwright browser is used. The active
    page is the last page whose URL is not ``about:blank``, falling back to the
    first page when every page is blank.

    :param browser_context: object exposing ``browser.playwright_browser``.
    :return: base64 string of the screenshot, or ``None`` when there is no
        browser, context or page, or when taking the screenshot fails.
    """
    browser = browser_context.browser.playwright_browser
    if not (browser and browser.contexts):
        return None

    context = browser.contexts[0]
    pages = context.pages if context else None
    if not pages:
        return None

    # Prefer the last page that actually shows content.
    non_blank = [page for page in pages if page.url != "about:blank"]
    target = non_blank[-1] if non_blank else pages[0]

    try:
        raw = await target.screenshot(type='jpeg', quality=75, scale="css")
        return base64.b64encode(raw).decode('utf-8')
    except Exception:
        # Best effort: callers treat None as "no screenshot available".
        return None
class ConfigManager:
    """Keeps track of named Gradio components so their values can be saved and restored."""

    def __init__(self):
        # name -> component, plus the order in which names were first registered
        self.components = {}
        self.component_order = []

    def register_component(self, name: str, component):
        """Store *component* under *name* and return it.

        Re-registering an existing name replaces the component but keeps its
        original position in the ordering.
        """
        self.components[name] = component
        already_known = name in self.component_order
        if not already_known:
            self.component_order.append(name)
        return component

    def save_current_config(self):
        """Write each registered component's ``value`` attribute to a config file.

        Components without a ``value`` attribute are stored as ``None``.
        Returns whatever ``save_config_to_file`` returns.
        """
        snapshot = {
            name: getattr(self.components[name], "value", None)
            for name in self.component_order
        }
        return save_config_to_file(snapshot)

    def update_ui_from_config(self, config_file):
        """Build Gradio updates from an uploaded configuration file.

        Returns one update per registered component (in registration order)
        followed by a status message. When no file is given, or the file does
        not load as a dict, every update is a no-op.
        """
        def unchanged(status):
            return [gr.update() for _ in self.component_order] + [status]

        if config_file is None:
            return unchanged("No file selected.")

        loaded = load_config_from_file(config_file.name)
        if not isinstance(loaded, dict):
            return unchanged("Error: Invalid configuration file.")

        # Components missing from the file are left as they are.
        result = [
            gr.update(value=loaded[name]) if name in loaded else gr.update()
            for name in self.component_order
        ]
        result.append("Configuration loaded successfully.")
        return result

    def get_all_components(self):
        """Return the registered components in registration order."""
        ordered = []
        for name in self.component_order:
            ordered.append(self.components[name])
        return ordered
def load_config_from_file(config_file):
    """
    Load settings from a JSON config file.

    The file is read as UTF-8 so that non-ASCII values decode the same way on
    every platform, instead of depending on the locale's default encoding.

    :param config_file: path of the JSON file to read.
    :return: the parsed JSON value on success; on any failure, a string
        starting with ``"Error loading configuration: "``. Callers tell the
        two apart by checking the type of the result.
    """
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        # Deliberately broad: the error is reported as text rather than raised.
        return f"Error loading configuration: {str(e)}"
def save_config_to_file(settings, save_dir="./tmp/webui_settings"):
    """
    Write *settings* as indented JSON to a new, uniquely named file.

    :param settings: JSON-serializable settings object.
    :param save_dir: directory for the file; created if it does not exist.
    :return: a status message containing the path of the written file.
    """
    os.makedirs(save_dir, exist_ok=True)
    file_name = f"{uuid.uuid4()}.json"
    target_path = os.path.join(save_dir, file_name)
    with open(target_path, 'w') as out:
        json.dump(settings, out, indent=2)
    return f"Configuration saved to {target_path}"