File size: 3,863 Bytes
4e424ea
ca753f0
4e424ea
7bedcdd
4e424ea
 
 
 
 
 
 
 
ca753f0
f0f4c78
935512c
 
 
 
3217fc0
935512c
 
 
3217fc0
935512c
3217fc0
935512c
 
 
 
4e424ea
f0f4c78
4e424ea
 
 
 
 
f0f4c78
4e424ea
 
 
f0f4c78
 
 
 
 
 
 
 
4e424ea
ca753f0
6c641ac
 
 
935512c
3217fc0
935512c
 
 
 
f20624c
935512c
 
f20624c
 
935512c
 
 
 
 
 
 
 
 
 
 
3217fc0
 
 
 
 
935512c
ca753f0
f20624c
935512c
f0f4c78
3217fc0
f20624c
935512c
4f2bf09
f0f4c78
4e424ea
d701afa
4e424ea
 
f0f4c78
4e424ea
 
 
 
c81f025
4e424ea
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import re
import subprocess
import sys

import gradio as gr
from huggingface_hub import snapshot_download
from tqdm import tqdm

# Fetch the Wan2.1 T2V 1.3B checkpoint at import time into the directory that
# infer() later passes to the generator as --ckpt_dir.
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./Wan2.1-T2V-1.3B")

# Extracts the message text from a log line of the form "[...] INFO: <message>".
_INFO_PATTERN = re.compile(r"\[.*?\]\s+INFO:\s+(.*)")
# Captures tqdm-style progress lines (like " 10%|...| 5/50"): groups are
# percent, current step, total steps.
_PROGRESS_PATTERN = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")


def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a video for *prompt* by running the local ``generate`` module.

    The generator runs as a subprocess. Its combined stdout/stderr is streamed
    line by line, echoed via ``tqdm.write``, and mirrored onto two tqdm bars.
    Gradio surfaces both bars in the UI because of ``track_tqdm=True``:

    * "Video Generation" follows tqdm-style progress lines printed by the child.
    * "Overall Process" advances once per INFO log line, after skipping the
      first three INFO messages.

    Args:
        prompt: Text prompt forwarded to the generator's ``--prompt`` flag.
        progress: Gradio progress tracker. It is not used directly; its
            ``track_tqdm=True`` setting is what makes Gradio display the bars.

    Returns:
        The path of the generated video, ``"generated_video.mp4"``.

    Raises:
        gr.Error: If the generation subprocess exits with a non-zero status.
    """
    total_process_steps = 12
    irrelevant_steps = 3
    relevant_steps = total_process_steps - irrelevant_steps  # 9 steps
    output_path = "generated_video.mp4"

    command = [
        # sys.executable, instead of whatever "python" is on PATH, runs the
        # generator under the same interpreter/virtualenv as this app.
        # -u gives unbuffered output so progress lines arrive as they happen.
        sys.executable, "-u", "-m", "generate",
        "--task", "t2v-1.3B",
        "--size", "832*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--save_file", output_path,
    ]

    overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1, dynamic_ncols=True, leave=True)
    gen_progress_bar = None
    skipped_info_lines = 0

    # Merge stderr into stdout so progress bars and logs arrive in one stream.
    # errors="replace" keeps a stray undecodable byte from aborting the loop.
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors="replace",
        bufsize=1,  # line-buffered
    )
    try:
        for line in process.stdout:
            stripped_line = line.strip()
            if not stripped_line:
                continue

            progress_match = _PROGRESS_PATTERN.search(stripped_line)
            if progress_match:
                current = int(progress_match.group(2))
                total = int(progress_match.group(3))
                if gen_progress_bar is None:
                    gen_progress_bar = tqdm(total=total, desc="Video Generation", position=0, dynamic_ncols=True, leave=True)
                elif gen_progress_bar.total != total:
                    # A progress loop with a different length started; restart the bar for it.
                    gen_progress_bar.reset(total=total)
                # The child reports absolute positions, so advance by the difference.
                gen_progress_bar.update(current - gen_progress_bar.n)
                gen_progress_bar.refresh()
                continue  # Progress lines are not echoed.

            info_match = _INFO_PATTERN.search(stripped_line)
            if info_match:
                if skipped_info_lines < irrelevant_steps:
                    # The first INFO messages are not counted as process steps.
                    skipped_info_lines += 1
                else:
                    # Cap at the bar's total so extra INFO lines cannot push it past 100%.
                    if overall_bar.n < overall_bar.total:
                        overall_bar.update(1)
                    percentage = (overall_bar.n / overall_bar.total) * 100
                    overall_bar.set_description(f"Overall Process - {percentage:.0f}% | {info_match.group(1)}")
            tqdm.write(stripped_line)

        process.wait()
    finally:
        # On an exception during streaming, do not leave the generator running.
        if process.poll() is None:
            process.kill()
            process.wait()
        process.stdout.close()
        if gen_progress_bar is not None:
            gen_progress_bar.close()
        overall_bar.close()

    if process.returncode != 0:
        print("Error executing command.")
        raise gr.Error(f"Error executing command (exit code {process.returncode})")
    print("Command executed successfully.")
    return output_path

# Minimal UI: a prompt box and a submit button; the video player shows the
# file path returned by infer().
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1")
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        video_res = gr.Video(label="Generated Video")

    # Event wiring must stay inside the Blocks context.
    submit_btn.click(fn=infer, inputs=[prompt], outputs=[video_res])

# Enable the request queue, then start the server.
demo.queue()
demo.launch(show_error=True, show_api=False, ssr_mode=False)