Spaces:
Running
Running
import html
import random
import threading

import fastapi
import gradio as gr
import psutil
import uvicorn

from viser_proxy_manager import ViserProxyManager
from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage
# Trajectory data shared by every viewer session; main() assigns it once at
# startup and passes it to each session's visualization thread.
global_data_cache = None
def check_ram_usage(threshold_percent=90):
    """Report whether system RAM usage is currently under a percentage limit.

    Prints the current usage percentage as a side effect.

    Args:
        threshold_percent: Usage percentage at or above which this returns False.

    Returns:
        bool: True when current RAM usage is strictly below ``threshold_percent``.
    """
    memory_stats = psutil.virtual_memory()
    used_pct = memory_stats.percent
    print(f"Current RAM usage: {used_pct}%")
    is_below = used_pct < threshold_percent
    return is_below
# RAM usage (percent) at or above which start_server refuses to load a viewer.
# NOTE(review): usage is a percentage, so a limit of 100 effectively disables
# this guard; kept at 100 to preserve existing behavior — lower it to shed load.
_RAM_THRESHOLD_PERCENT = 100

# Viewer HTML / status text returned when the RAM guard refuses a session.
_HIGH_LOAD_HTML = """
<div style="text-align: center; padding: 20px; background-color: #ffeeee; border-radius: 5px;">
    <h2>⚠️ Server is currently under high load</h2>
    <p>Please try again later when resources are available.</p>
</div>
"""
_HIGH_LOAD_STATUS = (
    "**System Status:** High memory usage detected. "
    "Visualization not loaded to prevent server overload."
)
_LOADED_STATUS = "**System Status:** Visualization loaded successfully."


def _session_hash(request: "gr.Request") -> str:
    """Return the Gradio session hash of *request*, raising if it is missing."""
    session_hash = request.session_hash
    if session_hash is None:
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        raise ValueError("Gradio request has no session_hash")
    return session_hash


def _request_protocol(request: "gr.Request") -> str:
    """Return "https" if the fronting proxy reports an HTTPS client connection, else "http".

    X-Forwarded-Proto may hold a comma-separated list when proxies are chained;
    only the first entry (the client-facing hop) is considered.
    """
    forwarded = request.headers.get("x-forwarded-proto", "")
    first_hop = forwarded.split(",", 1)[0].strip().lower()
    return "https" if first_hop == "https" else "http"


def _viewer_iframe_html(protocol: str, host: str, session_hash: str) -> str:
    """Return iframe markup embedding the proxied Viser server for one session."""
    # ``host`` comes from a client-supplied header; escape it before it lands in HTML.
    src = html.escape(f"{protocol}://{host}/viser/{session_hash}/", quote=True)
    return f"""
<iframe
    src="{src}"
    width="100%"
    height="500px"
    frameborder="0"
    style="display: block;"
    allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
    loading="lazy"
></iframe>
"""


def main() -> None:
    """Load trajectory data, build the Gradio UI, and serve it with uvicorn on port 7860.

    Each browser session gets its own Viser server, created through
    ViserProxyManager and shown in an iframe. The server is started when the
    page loads and stopped when the tab is closed or refreshed.
    """
    global global_data_cache
    # Load once at startup so every session reuses the same data.
    global_data_cache = load_trajectory_data(
        use_float16=True,
        max_frames=69,
        traj_path="480p_train",
        mask_folder="train",
        conf_thre_percentile=3,
    )

    app = fastapi.FastAPI()
    viser_manager = ViserProxyManager(app)
    # Session hashes that actually have a running Viser server. Sessions turned
    # away by the RAM guard never start one, so stop() must skip them.
    active_sessions = set()

    with gr.Blocks(title="Viser Viewer") as demo:
        iframe_html = gr.HTML("")
        status_text = gr.Markdown("")

        def start_server(request: gr.Request):
            """Start this session's Viser server and return (iframe HTML, status text)."""
            session_hash = _session_hash(request)

            if not check_ram_usage(threshold_percent=_RAM_THRESHOLD_PERCENT):
                return _HIGH_LOAD_HTML, _HIGH_LOAD_STATUS

            # Build the iframe before starting anything, so a missing Host
            # header fails without leaving an orphaned server behind.
            iframe = _viewer_iframe_html(
                _request_protocol(request), request.headers["host"], session_hash
            )

            viser_manager.start_server(session_hash)
            active_sessions.add(session_hash)

            # Populate the scene on a background thread; this handler returns
            # without waiting for it.
            threading.Thread(
                target=visualize_st4rtrack,
                kwargs={
                    "server": viser_manager.get_server(session_hash),
                    "use_float16": True,
                    "preloaded_data": global_data_cache,
                    "color_code": "jet",
                    "blue_rgb": (0.0, 0.149, 0.463),  # #002676
                    "red_rgb": (0.769, 0.510, 0.055),  # #C4820E
                    "blend_ratio": 0.7,
                    "point_size": 0.0065,
                    "camera_position": (1e-3, 1.7, -0.125),
                },
                daemon=True,
            ).start()

            return iframe, _LOADED_STATUS

        def stop(request: gr.Request):
            """Stop this session's Viser server, if one was started."""
            session_hash = request.session_hash
            if session_hash is None or session_hash not in active_sessions:
                return
            active_sessions.discard(session_hash)
            viser_manager.stop_server(session_hash)

        # Without these registrations neither callback ever runs: no viewer is
        # shown and per-session servers are never cleaned up.
        demo.load(start_server, outputs=[iframe_html, status_text])
        demo.unload(stop)

    gr.mount_gradio_app(app, demo, "/")
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()