File size: 4,033 Bytes
4d1a850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c62ce96
4d1a850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5854e3b
8a4fb48
 
4d1a850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import html
import random
import threading

import fastapi
import gradio as gr
import psutil
import uvicorn

from viser_proxy_manager import ViserProxyManager
from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage

# Trajectory data shared by every viewer session. main() fills it once via
# load_trajectory_data(), and each visualize_st4rtrack() call receives it as
# ``preloaded_data``. It stays None until main() has run.
global_data_cache = None

def check_ram_usage(threshold_percent=90):
    """Report whether system memory usage is under a limit.

    As a side effect, the current usage figure from
    ``psutil.virtual_memory().percent`` is printed to stdout.

    Args:
        threshold_percent: Usage percentage treated as the ceiling; usage at
            or above this value counts as "too high".

    Returns:
        bool: True when usage is strictly below ``threshold_percent``,
        otherwise False.
    """
    memory = psutil.virtual_memory()
    usage = memory.percent
    print(f"Current RAM usage: {usage}%")
    below_limit = usage < threshold_percent
    return below_limit


# Memory-usage limit (percent) passed to check_ram_usage() before a session's
# viewer is started. check_ram_usage() refuses only when usage is >= this
# value, so 100 effectively disables the guard.
# NOTE(review): confirm disabling is intended; check_ram_usage defaults to 90.
RAM_THRESHOLD_PERCENT = 100


def _require_session_hash(request: gr.Request) -> str:
    """Return ``request.session_hash``, raising if Gradio did not set one.

    Raises:
        RuntimeError: If the request carries no session hash. An explicit
            raise is used instead of ``assert`` so the check is not stripped
            under ``python -O``.
    """
    session_hash = request.session_hash
    if session_hash is None:
        raise RuntimeError("Gradio request has no session_hash")
    return session_hash


def _request_protocol(request: gr.Request) -> str:
    """Return ``"https"`` if the proxy reports an HTTPS client hop, else ``"http"``.

    A TLS-terminating proxy (e.g. HuggingFace Spaces) reports the original
    scheme in ``X-Forwarded-Proto``. A chain of proxies may send a
    comma-separated list whose first entry is the client-facing scheme.
    """
    forwarded = request.headers.get("x-forwarded-proto") or ""
    first_hop = forwarded.split(",", 1)[0].strip().lower()
    return "https" if first_hop == "https" else "http"


def main(server_host: str = "0.0.0.0", server_port: int = 7860) -> None:
    """Load trajectory data, then serve a Gradio page with one Viser viewer per session.

    The data is loaded once into the module-level ``global_data_cache`` and
    shared by all sessions. On page load, each Gradio session (keyed by its
    ``session_hash``) gets its own Viser server from ``ViserProxyManager``.
    That server is embedded as an iframe pointing at ``/viser/<session_hash>/``
    and is stopped on page unload. Blocks inside ``uvicorn.run``.

    Args:
        server_host: Interface uvicorn binds to.
        server_port: Port uvicorn listens on.
    """
    # Load data once at startup; every session's visualization reuses it.
    global global_data_cache
    global_data_cache = load_trajectory_data(
        use_float16=True,
        max_frames=69,
        traj_path="480p_train",
        mask_folder="train",
        conf_thre_percentile=3,
    )

    app = fastapi.FastAPI()
    viser_manager = ViserProxyManager(app)

    # Gradio page: an HTML slot for the viewer iframe plus a status line.
    with gr.Blocks(title="Viser Viewer") as demo:
        iframe_html = gr.HTML("")
        status_text = gr.Markdown("")

        @demo.load(outputs=[iframe_html, status_text])
        def start_server(request: gr.Request):
            """Start this session's Viser server; return (iframe HTML, status markdown)."""
            session_hash = _require_session_hash(request)

            # Refuse to start a viewer when memory usage is already too high.
            if not check_ram_usage(threshold_percent=RAM_THRESHOLD_PERCENT):
                return """
                <div style="text-align: center; padding: 20px; background-color: #ffeeee; border-radius: 5px;">
                    <h2>⚠️ Server is currently under high load</h2>
                    <p>Please try again later when resources are available.</p>
                </div>
                """, "**System Status:** High memory usage detected. Visualization not loaded to prevent server overload."

            viser_manager.start_server(session_hash)

            # The iframe points back at the host the browser used. The Host
            # header and the session hash both come from the client, so they
            # are HTML-escaped before being interpolated into markup.
            safe_host = html.escape(request.headers["host"], quote=True)
            safe_session = html.escape(session_hash, quote=True)
            protocol = _request_protocol(request)

            # Populate the scene on a background thread so this handler can
            # return the iframe immediately. daemon=True means the thread does
            # not keep the interpreter alive at exit.
            server = viser_manager.get_server(session_hash)
            threading.Thread(
                target=visualize_st4rtrack,
                kwargs={
                    "server": server,
                    "use_float16": True,
                    "preloaded_data": global_data_cache,  # shared, loaded once above
                    "color_code": "jet",
                    "blue_rgb": (0.0, 0.149, 0.463),  # #002676
                    # NOTE(review): an earlier comment said #FDB515, but these
                    # values are ~#C4820E; confirm which color is intended.
                    "red_rgb": (0.769, 0.510, 0.055),
                    "blend_ratio": 0.7,
                    "point_size": 0.0065,
                    "camera_position": (1e-3, 1.7, -0.125),
                },
                daemon=True,
            ).start()

            return f"""
            <iframe 
                src="{protocol}://{safe_host}/viser/{safe_session}/" 
                width="100%" 
                height="500px" 
                frameborder="0" 
                style="display: block;"
                allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
                loading="lazy"
            ></iframe>
            """, "**System Status:** Visualization loaded successfully."

        @demo.unload
        def stop(request: gr.Request):
            # NOTE(review): this also runs for sessions refused by the RAM
            # check, for which start_server() was never called. Confirm that
            # ViserProxyManager.stop_server tolerates an unknown session.
            viser_manager.stop_server(_require_session_hash(request))

    gr.mount_gradio_app(app, demo, "/")
    uvicorn.run(app, host=server_host, port=server_port)


# Script entry point: start the viewer app only when run directly, not on import.
if __name__ == "__main__":
    main()