"""Gradio + FastAPI front end for the ``vis_st4rtrack`` viser viewer.

Trajectory data is loaded once in ``main`` and shared by every session. On
page load, each browser session gets its own viser server from
``ViserProxyManager``, and ``visualize_st4rtrack`` is run against that server
on a background daemon thread. The app is served by uvicorn on 0.0.0.0:7860.
"""

import random
import threading

import psutil
import fastapi
import gradio as gr
import uvicorn

from viser_proxy_manager import ViserProxyManager
from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage

# Trajectory data shared by all sessions; populated once by ``main``.
global_data_cache = None

# Arguments for the one-time ``load_trajectory_data`` call in ``main``.
_TRAJECTORY_LOAD_KWARGS = dict(
    use_float16=True,
    max_frames=69,
    traj_path="480p_train",
    mask_folder="train",
    conf_thre_percentile=3,
)

# Fixed options forwarded to ``visualize_st4rtrack`` for every session
# (``server`` and ``preloaded_data`` are added per call).
_VISUALIZATION_OPTIONS = {
    "use_float16": True,
    "color_code": "jet",
    "blue_rgb": (0.0, 0.149, 0.463),  # ≈ #002676
    # ≈ #C4820E. NOTE(review): this was labelled #FDB515, which would be
    # (0.992, 0.710, 0.082) — confirm which color is intended.
    "red_rgb": (0.769, 0.510, 0.055),
    "blend_ratio": 0.7,
    "point_size": 0.0065,
    "camera_position": (1e-3, 1.7, -0.125),
}

# (HTML, status) returned when the RAM check refuses to start a session.
_HIGH_LOAD_HTML = (
    "\n\n⚠️ Server is currently under high load\n\n"
    "Please try again later when resources are available.\n\n"
)
_HIGH_LOAD_STATUS = (
    "**System Status:** High memory usage detected. "
    "Visualization not loaded to prevent server overload."
)

# (HTML, status) returned after a session's visualization is launched.
# NOTE(review): this fragment is blank, so no viewer/iframe markup is
# rendered, and the `host`/`protocol` values computed in start_server are
# never interpolated anywhere. The iframe markup looks missing — confirm.
_VIEWER_HTML = " "
_LOADED_STATUS = "**System Status:** Visualization loaded successfully."


def check_ram_usage(threshold_percent=90):
    """Report whether system RAM usage is below a threshold.

    The current usage percentage is printed as a side effect.

    Args:
        threshold_percent: Usage percentage at or above which the check fails.

    Returns:
        bool: True when ``psutil.virtual_memory().percent`` is strictly below
        ``threshold_percent``, otherwise False.
    """
    usage = psutil.virtual_memory().percent
    print(f"Current RAM usage: {usage}%")
    below_threshold = usage < threshold_percent
    return below_threshold


def main() -> None:
    """Load shared data, build the Gradio UI on FastAPI, and serve it with uvicorn."""
    global global_data_cache
    # Load once at startup so every session reuses the same data.
    global_data_cache = load_trajectory_data(**_TRAJECTORY_LOAD_KWARGS)

    app = fastapi.FastAPI()
    viser_manager = ViserProxyManager(app)

    with gr.Blocks(title="Viser Viewer") as demo:
        # Receives the viewer HTML returned by start_server.
        iframe_html = gr.HTML("")
        # Receives the markdown status line returned by start_server.
        status_text = gr.Markdown("")

        @demo.load(outputs=[iframe_html, status_text])
        def start_server(request: gr.Request):
            """Start this session's viser server and launch the visualization.

            Returns:
                (html, status_markdown) for the two output components above.
            """
            session_id = request.session_hash
            assert session_id is not None

            # NOTE(review): `usage < 100` fails only at 100% usage, so this
            # guard likely never trips; check_ram_usage defaults to 90.
            # Confirm the threshold is intentional.
            if not check_ram_usage(threshold_percent=100):
                return _HIGH_LOAD_HTML, _HIGH_LOAD_STATUS

            viser_manager.start_server(session_id)

            # Public host and scheme of this request (HTTPS when a proxy sets
            # x-forwarded-proto=https). See the NOTE on _VIEWER_HTML: these
            # are currently unused.
            host = request.headers["host"]
            if request.headers.get("x-forwarded-proto") == "https":
                protocol = "https"
            else:
                protocol = "http"

            server = viser_manager.get_server(session_id)

            # Fresh kwargs per session; the preloaded data is read at call
            # time, after main() has populated it.
            visualization_kwargs = dict(_VISUALIZATION_OPTIONS)
            visualization_kwargs["server"] = server
            visualization_kwargs["preloaded_data"] = global_data_cache

            # Background daemon thread so this handler returns immediately.
            worker = threading.Thread(
                target=visualize_st4rtrack,
                kwargs=visualization_kwargs,
                daemon=True,
            )
            worker.start()

            return _VIEWER_HTML, _LOADED_STATUS

        @demo.unload
        def stop(request: gr.Request):
            """Stop this session's viser server on Gradio's unload event."""
            session_id = request.session_hash
            assert session_id is not None
            # NOTE(review): this also runs for sessions refused by the RAM
            # check, whose server was never started. Whether stop_server
            # tolerates that is not visible here.
            viser_manager.stop_server(session_id)

    gr.mount_gradio_app(app, demo, "/")
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()