# viser_train1 / app.py
# Junyi42: "Update app.py" (commit c62ce96, verified)
import random
import threading
import psutil
import fastapi
import gradio as gr
import uvicorn
from viser_proxy_manager import ViserProxyManager
from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage
# Global cache for the loaded trajectory data. It stays None until main()
# fills it once at startup; main() then hands the same object to every
# visualization thread it starts.
global_data_cache = None
def check_ram_usage(threshold_percent=90):
    """Report whether system RAM usage is still under a percentage limit.

    The current usage is also printed to stdout on every call.

    Args:
        threshold_percent: RAM usage percentage that must not be reached.

    Returns:
        bool: True when the current usage is strictly below
        ``threshold_percent``, False otherwise.
    """
    memory_stats = psutil.virtual_memory()
    used_percent = memory_stats.percent
    print(f"Current RAM usage: {used_percent}%")
    has_headroom = used_percent < threshold_percent
    return has_headroom
# Markup and status line returned when the RAM check refuses a new session.
_HIGH_LOAD_HTML = """
<div style="text-align: center; padding: 20px; background-color: #ffeeee; border-radius: 5px;">
<h2>⚠️ Server is currently under high load</h2>
<p>Please try again later when resources are available.</p>
</div>
"""
_HIGH_LOAD_STATUS = "**System Status:** High memory usage detected. Visualization not loaded to prevent server overload."


def _viewer_iframe_html(protocol, host, session_hash):
    """Build the iframe markup that embeds one session's viser viewer.

    Args:
        protocol: URL scheme for the iframe source ("http" or "https").
        host: Value of the request's ``host`` header.
        session_hash: Gradio session hash used as the viser server key.

    Returns:
        str: HTML for an iframe whose source is
        ``<protocol>://<host>/viser/<session_hash>/``.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    from html import escape

    # Both values are read from the incoming request and land inside a
    # double-quoted HTML attribute, so escape them before interpolation.
    safe_host = escape(host, quote=True)
    safe_session = escape(session_hash, quote=True)
    return f"""
<iframe
src="{protocol}://{safe_host}/viser/{safe_session}/"
width="100%"
height="500px"
frameborder="0"
style="display: block;"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
loading="lazy"
></iframe>
"""


def main() -> None:
    """Load the trajectory data once, build the Gradio page, and serve it.

    The page consists of an HTML slot for the viewer iframe and a Markdown
    status line. When a browser session loads the page, a viser server is
    started for that session and ``visualize_st4rtrack`` is launched on a
    daemon thread with the preloaded data. When the session unloads, its
    viser server is stopped. The Gradio app is mounted at ``/`` on a FastAPI
    app, which uvicorn serves on ``0.0.0.0:7860`` (this call blocks).
    """
    # Load data once at startup using the function from vis_st4rtrack.py
    global global_data_cache
    global_data_cache = load_trajectory_data(
        use_float16=True,
        max_frames=69,
        traj_path="480p_train",
        mask_folder="train",
        conf_thre_percentile=3,
    )

    app = fastapi.FastAPI()
    viser_manager = ViserProxyManager(app)

    # Create a Gradio interface with a title, the viewer iframe slot, and a
    # status line.
    with gr.Blocks(title="Viser Viewer") as demo:
        iframe_html = gr.HTML("")  # Filled with the iframe (or a notice) on load
        status_text = gr.Markdown("")  # Status text component

        @demo.load(outputs=[iframe_html, status_text])
        def start_server(request: gr.Request):
            """Start this session's viser server and return (html, status)."""
            assert request.session_hash is not None

            # Check RAM usage before starting visualization.
            # NOTE(review): with a threshold of 100 the check only refuses
            # when reported usage is not below 100% -- confirm that this
            # near-disabled guard is intended (the helper's default is 90).
            if not check_ram_usage(threshold_percent=100):
                return _HIGH_LOAD_HTML, _HIGH_LOAD_STATUS

            viser_manager.start_server(request.session_hash)

            # Build the iframe URL from the host the client used to reach us.
            host = request.headers["host"]

            # Determine protocol (use HTTPS for HuggingFace Spaces or other
            # secure environments that set x-forwarded-proto).
            if request.headers.get("x-forwarded-proto") == "https":
                protocol = "https"
            else:
                protocol = "http"

            # Run the visualization in a separate daemon thread so this
            # handler can return the iframe immediately.
            server = viser_manager.get_server(request.session_hash)
            threading.Thread(
                target=visualize_st4rtrack,
                kwargs={
                    "server": server,
                    "use_float16": True,
                    "preloaded_data": global_data_cache,  # Pass the preloaded data
                    "color_code": "jet",
                    "blue_rgb": (0.0, 0.149, 0.463),  # #002676
                    "red_rgb": (0.769, 0.510, 0.055),  # approx. #C4820E (gold)
                    "blend_ratio": 0.7,
                    "point_size": 0.0065,
                    "camera_position": (1e-3, 1.7, -0.125),
                },
                daemon=True,
            ).start()

            return (
                _viewer_iframe_html(protocol, host, request.session_hash),
                "**System Status:** Visualization loaded successfully.",
            )

        @demo.unload
        def stop(request: gr.Request):
            """Stop the viser server that belongs to the departing session."""
            assert request.session_hash is not None
            viser_manager.stop_server(request.session_hash)

    gr.mount_gradio_app(app, demo, "/")
    uvicorn.run(app, host="0.0.0.0", port=7860)
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()