#!/usr/bin/env python3
"""
Gradio Interface for Multimodal Chat with SSH Tunnel Keepalive, GPU Monitoring, and API Fallback

This application provides a Gradio web interface for multimodal chat backed by a
self-hosted vLLM model. It establishes SSH tunnels to the vLLM server and the
nvidia-smi monitoring endpoint on the remote machine, and falls back to the
Hyperbolic API if the local model is unavailable.
"""
import os
import time
import threading
import logging
import base64
import json
import requests
from io import BytesIO
import gradio as gr
from openai import OpenAI
from ssh_tunneler import SSHTunnel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('app')
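
# All settings below are read from environment variables. A minimal example
# configuration (illustrative placeholder values only, matching the names used
# in the os.environ lookups that follow):
#
#   export SSH_HOST=gpu-box.example.com
#   export SSH_USERNAME=ubuntu
#   export SSH_PASSWORD=...
#   export REMOTE_PORT=8000          # vLLM's OpenAI-compatible port on the remote host
#   export LOCAL_PORT=8020           # local end of the forwarded tunnel
#   export MODEL_NAME=google/gemma-3-27b-it
#   export HYPERBOLIC_XYZ_KEY=...    # only needed for the fallback path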

# Get environment variables
SSH_HOST = os.environ.get('SSH_HOST')
SSH_PORT = int(os.environ.get('SSH_PORT', 22))
SSH_USERNAME = os.environ.get('SSH_USERNAME')
SSH_PASSWORD = os.environ.get('SSH_PASSWORD')
REMOTE_PORT = int(os.environ.get('REMOTE_PORT', 8000))  # vLLM API port on remote machine
LOCAL_PORT = int(os.environ.get('LOCAL_PORT', 8020))  # Local forwarded port
GPU_REMOTE_PORT = 5000  # GPU monitoring endpoint on remote machine
GPU_LOCAL_PORT = 5020  # Local forwarded port for GPU monitoring
VLLM_MODEL = os.environ.get('MODEL_NAME', 'google/gemma-3-27b-it')
HYPERBOLIC_KEY = os.environ.get('HYPERBOLIC_XYZ_KEY')
FALLBACK_MODEL = 'Qwen/Qwen2.5-VL-72B-Instruct'  # Fallback model at Hyperbolic

# Set the maximum number of concurrent API calls before queuing
MAX_CONCURRENT = int(os.environ.get('MAX_CONCURRENT', 3))  # Default to 3 concurrent calls

# API endpoints
VLLM_ENDPOINT = f"http://localhost:{LOCAL_PORT}/v1"
HYPERBOLIC_ENDPOINT = "https://api.hyperbolic.xyz/v1"
GPU_JSON_ENDPOINT = f"http://localhost:{GPU_LOCAL_PORT}/gpu/json"
GPU_TXT_ENDPOINT = f"http://localhost:{GPU_LOCAL_PORT}/gpu/txt"  # For backward compatibility

# Global variables
api_tunnel = None
gpu_tunnel = None
use_fallback = False  # Whether to use fallback API instead of local vLLM
api_tunnel_status = {"is_running": False, "message": "Initializing API tunnel..."}
gpu_tunnel_status = {"is_running": False, "message": "Initializing GPU monitoring tunnel..."}
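# Latest snapshot from the remote GPU monitor. The /gpu/json endpoint is expected
# to return roughly this shape (only the fields read by format_simplified_gpu_data()
# below are listed; extra fields are ignored):
#   {"success": true, "timestamp": "...", "processes": [...],
#    "gpus": [{"index": 0, "name": "...", "memory_used": ..., "memory_total": ...,
#              "memory_utilization": ..., "power_draw": ..., "power_limit": ...,
#              "fan_speed": ..., "temperature": ...}]}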
gpu_data = {"timestamp": "", "gpus": [], "processes": [], "success": False}
gpu_monitor_thread = None
gpu_monitor_running = False


def start_ssh_tunnels():
    """
    Start the SSH tunnels and monitor their status.
    """
    global api_tunnel, gpu_tunnel, use_fallback, api_tunnel_status, gpu_tunnel_status

    if not all([SSH_HOST, SSH_USERNAME, SSH_PASSWORD]):
        logger.error("Missing SSH connection details. Falling back to Hyperbolic API.")
        use_fallback = True
        api_tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
        gpu_tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
        return

    try:
        # Start API tunnel
        logger.info("Starting API SSH tunnel...")
        api_tunnel = SSHTunnel(
            ssh_host=SSH_HOST,
            ssh_port=SSH_PORT,
            username=SSH_USERNAME,
            password=SSH_PASSWORD,
            remote_port=REMOTE_PORT,
            local_port=LOCAL_PORT,
            reconnect_interval=30,
            keep_alive_interval=15
        )
        if api_tunnel.start():
            logger.info("API SSH tunnel started successfully")
            api_tunnel_status = {"is_running": True, "message": "Connected"}
        else:
            logger.warning("Failed to start API SSH tunnel. Falling back to Hyperbolic API.")
            use_fallback = True
            api_tunnel_status = {"is_running": False, "message": "Connection failed"}

        # Start GPU monitoring tunnel
        logger.info("Starting GPU monitoring SSH tunnel...")
        gpu_tunnel = SSHTunnel(
            ssh_host=SSH_HOST,
            ssh_port=SSH_PORT,
            username=SSH_USERNAME,
            password=SSH_PASSWORD,
            remote_port=GPU_REMOTE_PORT,
            local_port=GPU_LOCAL_PORT,
            reconnect_interval=30,
            keep_alive_interval=15
        )
        if gpu_tunnel.start():
            logger.info("GPU monitoring SSH tunnel started successfully")
            gpu_tunnel_status = {"is_running": True, "message": "Connected"}
            # Start GPU monitoring
            start_gpu_monitoring()
        else:
            logger.warning("Failed to start GPU monitoring SSH tunnel.")
            gpu_tunnel_status = {"is_running": False, "message": "Connection failed"}
    except Exception as e:
        logger.error(f"Error starting SSH tunnels: {str(e)}")
        use_fallback = True
        api_tunnel_status = {"is_running": False, "message": "Connection error"}
        gpu_tunnel_status = {"is_running": False, "message": "Connection error"}


def check_vllm_api_health():
    """
    Check if the vLLM API is actually responding by querying the /v1/models endpoint.

    Returns:
        tuple: (is_healthy, message)
    """
    try:
        response = requests.get(f"{VLLM_ENDPOINT}/models", timeout=5)
        if response.status_code == 200:
            try:
                data = response.json()
                if 'data' in data and len(data['data']) > 0:
                    model_id = data['data'][0].get('id', 'Unknown model')
                    return True, f"API is healthy. Available model: {model_id}"
                else:
                    return True, "API is healthy but no models found"
            except Exception as e:
                return False, f"API returned 200 but invalid JSON: {str(e)}"
        else:
            return False, f"API returned status code: {response.status_code}"
    except Exception as e:
        return False, f"API request failed: {str(e)}"


def fetch_gpu_info():
    """
    Fetch GPU information from the remote server in JSON format.

    Returns:
        dict: GPU information or error message
    """
    global gpu_tunnel_status
    try:
        response = requests.get(GPU_JSON_ENDPOINT, timeout=5)
        if response.status_code == 200:
            return response.json()
        else:
            logger.warning(f"Error fetching GPU info: HTTP {response.status_code}")
            return {
                "success": False,
                "error": f"HTTP Error: {response.status_code}",
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "gpus": [],
                "processes": []
            }
    except Exception as e:
        logger.warning(f"Error fetching GPU info: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "gpus": [],
            "processes": []
        }


def fetch_gpu_text():
    """
    Fetch raw nvidia-smi output from the remote server for backward compatibility.

    Returns:
        str: nvidia-smi output or error message
    """
    try:
        response = requests.get(GPU_TXT_ENDPOINT, timeout=5)
        if response.status_code == 200:
            return response.text
        else:
            return f"Error fetching GPU info: HTTP {response.status_code}"
    except Exception as e:
        return f"Error fetching GPU info: {str(e)}"


def start_gpu_monitoring():
    """
    Start the GPU monitoring thread.
    """
    global gpu_monitor_thread, gpu_monitor_running, gpu_data

    if gpu_monitor_running:
        return
    gpu_monitor_running = True

    def monitor_loop():
        global gpu_data
        while gpu_monitor_running:
            try:
                gpu_data = fetch_gpu_info()
            except Exception as e:
                logger.error(f"Error in GPU monitoring loop: {str(e)}")
                gpu_data = {
                    "success": False,
                    "error": str(e),
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                    "gpus": [],
                    "processes": []
                }
            time.sleep(2)  # Update every 2 seconds

    gpu_monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
    gpu_monitor_thread.start()
    logger.info("GPU monitoring thread started")


def process_chat(message_dict, history):
    """
    Process a user message and stream the response from the appropriate API.

    Args:
        message_dict (dict): User message containing text and files
        history (list): Chat history

    Yields:
        list: Updated chat history, re-yielded as the response streams in
    """
    global use_fallback

    text = message_dict.get("text", "")
    files = message_dict.get("files", [])

    if not history:
        history = []

    if files:
        for file in files:
            history.append({"role": "user", "content": (file,)})
    if text.strip():
        history.append({"role": "user", "content": text})
    else:
        if not files:
            history.append({"role": "user", "content": ""})

    base64_images = convert_files_to_base64(files)

    openai_messages = []
    for h in history:
        if h["role"] == "user":
            if isinstance(h["content"], tuple):
                continue
            else:
                openai_messages.append({
                    "role": "user",
                    "content": h["content"]
                })
        elif h["role"] == "assistant":
            openai_messages.append({
                "role": "assistant",
                "content": h["content"]
            })

    if base64_images:
        if openai_messages and openai_messages[-1]["role"] == "user":
            last_msg = openai_messages[-1]
            content_list = []
            if last_msg["content"]:
                content_list.append({"type": "text", "text": last_msg["content"]})
            for img_b64 in base64_images:
                content_list.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{img_b64}"
                    }
                })
            last_msg["content"] = content_list
    try:
        client = get_openai_client()
        model = get_model_name()
        response = client.chat.completions.create(
            model=model,
            messages=openai_messages,
            stream=True
        )
        assistant_message = ""
        for chunk in response:
            if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
                assistant_message += chunk.choices[0].delta.content
                history_with_stream = history.copy()
                history_with_stream.append({"role": "assistant", "content": assistant_message})
                yield history_with_stream
        if not assistant_message:
            assistant_message = "No response received from the model."
        if not history or history[-1]["role"] != "assistant":
            history.append({"role": "assistant", "content": assistant_message})
        # This is a generator, so the final state must be yielded; a plain
        # return value would be dropped by Gradio.
        yield history
    except Exception as primary_error:
        logger.error(f"Primary API error: {str(primary_error)}")
        if not use_fallback:
            try:
                logger.info("Falling back to Hyperbolic API")
                client = get_openai_client(use_fallback_api=True)
                model = get_model_name(use_fallback_api=True)
                response = client.chat.completions.create(
                    model=model,
                    messages=openai_messages,
                    stream=True
                )
                assistant_message = ""
                for chunk in response:
                    if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
                        assistant_message += chunk.choices[0].delta.content
                        history_with_stream = history.copy()
                        history_with_stream.append({"role": "assistant", "content": assistant_message})
                        yield history_with_stream
                if not assistant_message:
                    assistant_message = "No response received from the fallback model."
                if not history or history[-1]["role"] != "assistant":
                    history.append({"role": "assistant", "content": assistant_message})
                use_fallback = True
                yield history
            except Exception as fallback_error:
                logger.error(f"Fallback API error: {str(fallback_error)}")
                error_msg = "Error connecting to both primary and fallback APIs."
                history.append({"role": "assistant", "content": error_msg})
                yield history
        else:
            error_msg = "An error occurred with the model service."
            history.append({"role": "assistant", "content": error_msg})
            yield history


def monitor_tunnels():
    """
    Monitor the SSH tunnels' status and update the global status variables.
    """
    global api_tunnel, gpu_tunnel, use_fallback, api_tunnel_status, gpu_tunnel_status

    logger.info("Starting tunnel monitoring thread")
    while True:
        try:
            if api_tunnel is not None:
                ssh_status = api_tunnel.check_status()
                if ssh_status["is_running"]:
                    is_healthy, message = check_vllm_api_health()
                    if is_healthy:
                        use_fallback = False
                        api_tunnel_status = {
                            "is_running": True,
                            "message": f"Connected and healthy. {message}"
                        }
                    else:
                        use_fallback = True
                        api_tunnel_status = {
                            "is_running": False,
                            "message": "Tunnel connected but vLLM API unhealthy"
                        }
                else:
                    logger.error(f"API SSH tunnel disconnected: {ssh_status.get('error', 'Unknown error')}")
                    use_fallback = True
                    api_tunnel_status = {
                        "is_running": False,
                        "message": "Disconnected - Check server status"
                    }
            else:
                use_fallback = True
                api_tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}

            if gpu_tunnel is not None:
                ssh_status = gpu_tunnel.check_status()
                if ssh_status["is_running"]:
                    gpu_tunnel_status = {
                        "is_running": True,
                        "message": "Connected"
                    }
                    if not gpu_monitor_running:
                        start_gpu_monitoring()
                else:
                    logger.error(f"GPU SSH tunnel disconnected: {ssh_status.get('error', 'Unknown error')}")
                    gpu_tunnel_status = {
                        "is_running": False,
                        "message": "Disconnected - Check server status"
                    }
            else:
                gpu_tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}
        except Exception as e:
            logger.error(f"Error monitoring tunnels: {str(e)}")
            use_fallback = True
            api_tunnel_status = {"is_running": False, "message": "Monitoring error"}
            gpu_tunnel_status = {"is_running": False, "message": "Monitoring error"}
        time.sleep(5)  # Check every 5 seconds


def get_openai_client(use_fallback_api=None):
    """
    Create and return an OpenAI client configured for the appropriate endpoint.

    Args:
        use_fallback_api (bool): If True, use Hyperbolic API. If False, use local vLLM.
                                 If None, use the global use_fallback setting.

    Returns:
        OpenAI: Configured OpenAI client
    """
    global use_fallback
    if use_fallback_api is None:
        use_fallback_api = use_fallback

    if use_fallback_api:
        logger.info("Using Hyperbolic API")
        return OpenAI(
            api_key=HYPERBOLIC_KEY,
            base_url=HYPERBOLIC_ENDPOINT
        )
    else:
        logger.info("Using local vLLM API")
        return OpenAI(
            api_key="EMPTY",  # vLLM doesn't require an actual API key
            base_url=VLLM_ENDPOINT
        )
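
# Both backends speak the OpenAI chat-completions protocol, so the calling code is
# identical either way. A minimal sketch of how process_chat uses this pair:
#
#   client = get_openai_client()   # tunneled vLLM server, or Hyperbolic on fallback
#   model = get_model_name()       # VLLM_MODEL or FALLBACK_MODEL, kept in sync
#   stream = client.chat.completions.create(model=model, messages=[...], stream=True)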


def get_model_name(use_fallback_api=None):
    """
    Return the appropriate model name based on the API being used.

    Args:
        use_fallback_api (bool): If True, use fallback model. If None, use the global setting.

    Returns:
        str: Model name
    """
    global use_fallback
    if use_fallback_api is None:
        use_fallback_api = use_fallback
    return FALLBACK_MODEL if use_fallback_api else VLLM_MODEL


def convert_files_to_base64(files):
    """
    Convert uploaded files to base64 strings.

    Args:
        files (list): List of file paths

    Returns:
        list: List of base64-encoded strings
    """
    base64_images = []
    for file in files:
        with open(file, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
            base64_images.append(base64_data)
    return base64_images


def format_simplified_gpu_data(gpu_data):
    """
    Format GPU data into a simplified, focused display.

    Args:
        gpu_data (dict): GPU data in JSON format

    Returns:
        str: Formatted GPU data
    """
    if not gpu_data.get("success", False):
        return f"Error fetching GPU data: {gpu_data.get('error', 'Unknown error')}"

    output = []
    output.append(f"Last updated: {gpu_data.get('timestamp', 'Unknown')}")
    for i, gpu in enumerate(gpu_data.get("gpus", [])):
        output.append(f"GPU {gpu.get('index', i)}: {gpu.get('name', 'Unknown')}")
        output.append(f" Memory: {gpu.get('memory_used', 0):6.0f} MB / {gpu.get('memory_total', 0):6.0f} MB ({gpu.get('memory_utilization', 0):5.1f}%)")
        output.append(f" Power: {gpu.get('power_draw', 0):5.1f}W / {gpu.get('power_limit', 0):5.1f}W")
        if 'fan_speed' in gpu:
            output.append(f" Fan: {gpu.get('fan_speed', 0):5.1f}%")
        output.append(f" Temp: {gpu.get('temperature', 0):5.1f}°C")
        output.append("")
    return "\n".join(output)


def update_gpu_status():
    """
    Fetch and format the current GPU status.

    Returns:
        str: Formatted GPU status
    """
    global gpu_data, gpu_tunnel_status
    if not gpu_tunnel_status["is_running"]:
        return "GPU monitoring tunnel is not connected."
    return format_simplified_gpu_data(gpu_data)


def get_tunnel_status_message():
    """
    Return a formatted status message for display in the UI.
    """
    global api_tunnel_status, gpu_tunnel_status, use_fallback, MAX_CONCURRENT
    api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
    model = get_model_name()
    api_status_color = "🟢" if (api_tunnel_status["is_running"] and not use_fallback) else "🔴"
    api_status_text = api_tunnel_status["message"]
    gpu_status_color = "🟢" if gpu_tunnel_status["is_running"] else "🔴"
    gpu_status_text = gpu_tunnel_status["message"]
    return (f"{api_status_color} API Tunnel: {api_status_text}\n"
            f"{gpu_status_color} GPU Tunnel: {gpu_status_text}\n"
            f"Current API: {api_mode}\n"
            f"Current Model: {model}\n"
            f"Concurrent Requests: {MAX_CONCURRENT}")


def get_gpu_json():
    """
    Return the raw GPU JSON data for debugging.
    """
    global gpu_data
    return json.dumps(gpu_data, indent=2)


def toggle_api():
    """
    Toggle between local vLLM and Hyperbolic API.
    """
    global use_fallback
    use_fallback = not use_fallback
    api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
    model = get_model_name()
    return f"Switched to {api_mode} using {model}"


def update_concurrency(new_value):
    """
    Update the MAX_CONCURRENT value.

    Args:
        new_value (str): New concurrency value as string

    Returns:
        str: Status message
    """
    global MAX_CONCURRENT
    try:
        value = int(new_value)
        if value < 1:
            return f"Error: Concurrency must be at least 1. Keeping current value: {MAX_CONCURRENT}"
        MAX_CONCURRENT = value
        return f"Concurrency updated to {MAX_CONCURRENT}. You may need to refresh the page for all changes to take effect."
    except ValueError:
        return f"Error: Invalid number. Keeping current value: {MAX_CONCURRENT}"

# Start SSH tunnels and monitoring threads
if __name__ == "__main__":
    start_ssh_tunnels()
    monitor_thread = threading.Thread(target=monitor_tunnels, daemon=True)
    monitor_thread.start()

    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("# Multimodal Chat Interface")
        chatbot = gr.Chatbot(
            label="Conversation",
            type="messages",
            show_copy_button=True,
            avatar_images=("👤", "🗣️"),
            height=400
        )
        with gr.Row():
            textbox = gr.MultimodalTextbox(
                file_types=["image", "video"],
                file_count="multiple",
                placeholder="Type your message here and/or upload images...",
                label="Message",
                show_label=False,
                scale=9
            )
            submit_btn = gr.Button("Send", size="sm", scale=1)
        clear_btn = gr.Button("Clear Chat")
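
        # Both the textbox's submit event and the Send button run the same streaming
        # handler and then clear the textbox; concurrency_limit caps how many chat
        # requests are processed in parallel (the rest wait in the queue).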
        submit_event = textbox.submit(
            fn=process_chat,
            inputs=[textbox, chatbot],
            outputs=chatbot,
            concurrency_limit=MAX_CONCURRENT
        ).then(
            fn=lambda: {"text": "", "files": []},
            inputs=None,
            outputs=textbox
        )
        submit_btn.click(
            fn=process_chat,
            inputs=[textbox, chatbot],
            outputs=chatbot,
            concurrency_limit=MAX_CONCURRENT
        ).then(
            fn=lambda: {"text": "", "files": []},
            inputs=None,
            outputs=textbox
        )
        clear_btn.click(lambda: [], None, chatbot)

        examples = []
        example_images = {
            "dog_pic.jpg": "What breed is this?",
            "ghostimg.png": "What's in this image?",
            "newspaper.png": "Provide a python list of dicts about everything on this page."
        }
        for img_name, prompt_text in example_images.items():
            img_path = os.path.join(os.path.dirname(__file__), img_name)
            if os.path.exists(img_path):
                examples.append([{"text": prompt_text, "files": [img_path]}])
        if examples:
            gr.Examples(
                examples=examples,
                inputs=textbox
            )

        status_text = gr.Textbox(
            label="Tunnel and API Status",
            value=get_tunnel_status_message(),
            interactive=False
        )
        with gr.Accordion("GPU Status", open=False):
            # Changed from Textbox to HTML component
            gpu_status = gr.HTML(
                value=lambda: f"<pre style='font-family: monospace; white-space: pre; overflow: auto;'>{update_gpu_status()}</pre>",
                every=2
            )
        with gr.Row():
            refresh_btn = gr.Button("Refresh Status")
            toggle_api_btn = gr.Button("Toggle API")

        refresh_btn.click(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )
        toggle_api_btn.click(
            fn=toggle_api,
            inputs=None,
            outputs=status_text
        ).then(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )
        demo.load(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )
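
    # Queue incoming requests so that at most MAX_CONCURRENT chat calls hit the
    # model at once; anything beyond that waits in line rather than erroring out.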
    demo.queue(default_concurrency_limit=MAX_CONCURRENT)
    demo.launch()