File size: 3,431 Bytes
17f753b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608013f
17f753b
 
da4015e
17f753b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from spaces import GPU
import gradio as gr
import torch
import os
import time
from torchvision import models
from joblib import load
from extractor.visualise_vit_layer import VitGenerator
from relax_vqa import get_deep_feature, process_video_feature, process_patches, get_frame_patches, flow_to_rgb, merge_fragments, concatenate_features
from extractor.vf_extract import process_video_residual
from model_regression import Mlp, preprocess_data
from demo_test_gpu import evaluate_video_quality, load_model


@GPU
def run_relax_vqa(video_path, is_finetune, framerate, video_type):
    """Predict the perceptual quality score of an uploaded video with ReLaX-VQA.

    Args:
        video_path: Filesystem path of the uploaded video, or ``None``/empty
            when no file was provided.
        is_finetune: Value of the finetuning toggle; stored in the config
            passed to ``load_model`` and ``evaluate_video_quality``.
        framerate: Source video framerate in fps; stored in the config.
        video_type: Selected dataset name; stored in the config as
            ``video_type`` (presumably used to pick finetuned weights -
            TODO confirm against ``load_model``).

    Returns:
        A display string: either the predicted score and runtime, or an
        error message prefixed with "❌". Errors are returned as text rather
        than raised so the output box always shows a result.
    """
    # A missing upload arrives as None; os.path.exists(None) would raise
    # TypeError, so reject falsy paths before touching the filesystem.
    if not video_path or not os.path.exists(video_path):
        return "❌ No video uploaded or the uploaded file has expired. Please upload again."

    config = {
        'is_finetune': is_finetune,
        'framerate': framerate,
        'video_type': video_type,
        'save_path': 'model/',
        'train_data_name': 'lsvq_train',
        'select_criteria': 'byrmse',
        'video_path': video_path,
        'video_name': os.path.splitext(os.path.basename(video_path))[0]
    }

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Model construction lives inside the try so that load failures
        # (missing weights, download errors, ...) are reported in the UI and
        # the uploaded file is still cleaned up by the finally clause.
        resnet50 = models.resnet50(pretrained=True).to(device)
        vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=False)
        model_mlp = load_model(config, device)

        score, runtime = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
        return f"Predicted Quality Score: {score:.4f} (in {runtime:.2f}s)"
    except Exception as e:
        # Top-level UI boundary: surface any failure as text.
        return f"❌ Error: {e}"
    finally:
        # Only delete files whose path contains "gradio" (presumably Gradio's
        # temporary upload directory), never user-supplied local files.
        if "gradio" in video_path:
            try:
                os.remove(video_path)
            except FileNotFoundError:
                # Already removed (e.g. expired upload); nothing to clean up,
                # and raising here would mask the return value above.
                pass


def toggle_dataset_visibility(is_finetune):
    """Build a Gradio update that shows or hides a component.

    Used as the change handler of the finetuning checkbox: the target
    component is visible exactly when ``is_finetune`` is truthy-as-passed.
    """
    show_component = is_finetune
    return gr.update(visible=show_component)


with gr.Blocks() as demo:
    # Page title and introductory notes (Markdown with inline HTML).
    gr.Markdown("# 🎬 ReLaX-VQA Online Demo")
    intro_markdown = (
        "Upload a short video and get the predicted perceptual quality score using the ReLaX-VQA model. "
        "You can try our test video from the "
        "<a href='https://huggingface.co/spaces/xinyiW915/ReLaX-VQA/blob/main/ugc_original_videos/5636101558_540p.mp4' target='_blank'>demo video</a> "
        "(fps = 24, dataset = konvid_1k).<br><br>"
        "⚙️ This demo is currently running on <strong>Hugging Face ZeroGPU Space</strong>: Dynamic resources (NVIDIA A100)."
    )
    gr.Markdown(intro_markdown)

    # Datasets offered for finetuning; the first entry is the default.
    finetune_datasets = ["konvid_1k", "youtube_ugc", "live_vqc", "cvd_2014"]

    with gr.Row():
        # Left column: all user inputs plus the trigger button.
        with gr.Column(scale=2):
            video_input = gr.Video(label="Upload a Video (e.g. mp4)")
            framerate_slider = gr.Slider(
                label="Source Video Framerate (fps)",
                minimum=1,
                maximum=60,
                step=1,
                value=24,
            )
            is_finetune_checkbox = gr.Checkbox(label="Use Finetuning?", value=False)
            # Starts hidden; the checkbox change handler below reveals it.
            dataset_dropdown = gr.Dropdown(
                label="Source Video Dataset for Finetuning",
                choices=finetune_datasets,
                value=finetune_datasets[0],
                visible=False,
            )
            run_button = gr.Button("Run Prediction")
        # Right column: the prediction result (or error) text.
        with gr.Column(scale=1):
            output_box = gr.Textbox(label="Predicted Quality Score", lines=5)

    # Keep the dataset dropdown's visibility in sync with the checkbox.
    is_finetune_checkbox.change(
        fn=toggle_dataset_visibility,
        inputs=is_finetune_checkbox,
        outputs=dataset_dropdown,
    )

    # Order must match run_relax_vqa(video_path, is_finetune, framerate, video_type).
    prediction_inputs = [
        video_input,
        is_finetune_checkbox,
        framerate_slider,
        dataset_dropdown,
    ]
    run_button.click(
        fn=run_relax_vqa,
        inputs=prediction_inputs,
        outputs=output_box,
    )

demo.launch()