rahul7star committed on
Commit
711b244
·
verified ·
1 Parent(s): 942fdd0

Update simple_app.py

Browse files
Files changed (1) hide show
  1. simple_app.py +105 -37
simple_app.py CHANGED
@@ -1,22 +1,48 @@
1
  import gradio as gr
2
  import re
3
  import subprocess
 
4
  import select
 
5
  from huggingface_hub import snapshot_download
 
6
 
7
- # Download model (for demonstration, adjust based on actual model needs)
 
 
 
8
  snapshot_download(
9
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
10
  local_dir="./Wan2.1-T2V-1.3B"
11
  )
12
 
13
- # Function to generate video
14
  def infer(prompt, progress=gr.Progress(track_tqdm=True)):
15
- # Reduced progress output and simplified structure
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  command = [
17
- "python", "-u", "-m", "generate", # Using unbuffered output
18
  "--task", "t2v-1.3B",
19
- "--size", "832*480", # You can try reducing resolution further for CPU
20
  "--ckpt_dir", "./Wan2.1-T2V-1.3B",
21
  "--sample_shift", "8",
22
  "--sample_guide_scale", "6",
@@ -24,19 +50,14 @@ def infer(prompt, progress=gr.Progress(track_tqdm=True)):
24
  "--save_file", "generated_video.mp4"
25
  ]
26
 
27
- # Run the model inference in a subprocess
28
- process = subprocess.Popen(command,
29
- stdout=subprocess.PIPE,
30
- stderr=subprocess.PIPE, # Capture stderr for error messages
31
- text=True,
32
  bufsize=1)
33
 
34
- # Monitor progress with a minimal progress bar
35
- progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
36
- video_progress_bar = None
37
- overall_steps = 0
38
-
39
  while True:
 
40
  rlist, _, _ = select.select([process.stdout], [], [], 0.04)
41
  if rlist:
42
  line = process.stdout.readline()
@@ -46,50 +67,97 @@ def infer(prompt, progress=gr.Progress(track_tqdm=True)):
46
  if not stripped_line:
47
  continue
48
 
49
- # Check for video generation progress
50
  progress_match = progress_pattern.search(stripped_line)
51
  if progress_match:
 
 
 
 
 
 
 
 
 
 
52
  current = int(progress_match.group(2))
53
  total = int(progress_match.group(3))
54
  if video_progress_bar is None:
55
- video_progress_bar = gr.Progress()
56
- video_progress_bar.update(current / total)
57
- video_progress_bar.update(current / total)
 
 
 
 
 
 
 
58
  continue
59
 
60
- # Process info messages (simplified)
61
  if "INFO:" in stripped_line:
62
- overall_steps += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  continue
64
  else:
65
  print(stripped_line)
66
-
 
 
 
 
 
 
 
67
  if process.poll() is not None:
68
  break
69
 
70
- # Drain any remaining output from stderr
71
- stderr_output = process.stderr.read().strip()
72
- if stderr_output:
73
- print(f"Error output:\n{stderr_output}")
74
-
75
- # Clean up and finalize the progress bar
76
  process.wait()
77
- if video_progress_bar:
78
  video_progress_bar.close()
79
-
80
- # Check if the process finished successfully
 
 
81
  if process.returncode == 0:
 
82
  return "generated_video.mp4"
83
  else:
84
- print(f"Process failed with return code {process.returncode}")
85
- raise Exception(f"Error executing command: {stderr_output}")
86
 
87
- # Gradio UI
88
  with gr.Blocks() as demo:
89
  with gr.Column():
90
- gr.Markdown("# Wan 2.1 1.3B Video Generation")
 
91
  prompt = gr.Textbox(label="Prompt")
92
- submit_btn = gr.Button("Generate Video")
93
  video_res = gr.Video(label="Generated Video")
94
 
95
  submit_btn.click(
@@ -98,4 +166,4 @@ with gr.Blocks() as demo:
98
  outputs=[video_res]
99
  )
100
 
101
- demo.queue().launch(show_error=True, show_api=False)
 
1
  import gradio as gr
2
  import re
3
  import subprocess
4
+ import time
5
  import select
6
+ from tqdm import tqdm
7
  from huggingface_hub import snapshot_download
8
+ import torch
9
 
10
+ # Force the device to CPU
11
+ device = torch.device("cpu")
12
+
13
+ # Download model
14
  snapshot_download(
15
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
16
  local_dir="./Wan2.1-T2V-1.3B"
17
  )
18
 
 
19
  def infer(prompt, progress=gr.Progress(track_tqdm=True)):
20
+ # Configuration:
21
+ total_process_steps = 11 # Total INFO messages expected
22
+ irrelevant_steps = 4 # First 4 INFO messages are ignored
23
+ relevant_steps = total_process_steps - irrelevant_steps # 7 overall steps
24
+
25
+ # Create overall progress bar (Level 1)
26
+ overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1,
27
+ ncols=120, dynamic_ncols=False, leave=True)
28
+ processed_steps = 0
29
+
30
+ # Regex for video generation progress (Level 3)
31
+ progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
32
+ video_progress_bar = None
33
+
34
+ # Variables for sub-step progress bar (Level 2)
35
+ # Now using 1500 ticks to represent 60 seconds (each tick = 40 ms)
36
+ sub_bar = None
37
+ sub_ticks = 0
38
+ sub_tick_total = 1500
39
+ video_phase = False
40
+
41
+ # Command to run the video generation
42
  command = [
43
+ "python", "-u", "-m", "generate", # using -u for unbuffered output
44
  "--task", "t2v-1.3B",
45
+ "--size", "832*480",
46
  "--ckpt_dir", "./Wan2.1-T2V-1.3B",
47
  "--sample_shift", "8",
48
  "--sample_guide_scale", "6",
 
50
  "--save_file", "generated_video.mp4"
51
  ]
52
 
53
+ process = subprocess.Popen(command,
54
+ stdout=subprocess.PIPE,
55
+ stderr=subprocess.STDOUT,
56
+ text=True,
 
57
  bufsize=1)
58
 
 
 
 
 
 
59
  while True:
60
+ # Poll stdout with a 40ms timeout.
61
  rlist, _, _ = select.select([process.stdout], [], [], 0.04)
62
  if rlist:
63
  line = process.stdout.readline()
 
67
  if not stripped_line:
68
  continue
69
 
70
+ # Check for video generation progress (Level 3)
71
  progress_match = progress_pattern.search(stripped_line)
72
  if progress_match:
73
+ # If a sub-step bar is active, finish it before entering video phase.
74
+ if sub_bar is not None:
75
+ if sub_ticks < sub_tick_total:
76
+ sub_bar.update(sub_tick_total - sub_ticks)
77
+ sub_bar.close()
78
+ overall_bar.update(1)
79
+ overall_bar.refresh()
80
+ sub_bar = None
81
+ sub_ticks = 0
82
+ video_phase = True
83
  current = int(progress_match.group(2))
84
  total = int(progress_match.group(3))
85
  if video_progress_bar is None:
86
+ video_progress_bar = tqdm(total=total, desc="Video Generation", position=0,
87
+ ncols=120, dynamic_ncols=True, leave=True)
88
+ video_progress_bar.update(current - video_progress_bar.n)
89
+ video_progress_bar.refresh()
90
+ if video_progress_bar.n >= video_progress_bar.total:
91
+ video_phase = False
92
+ overall_bar.update(1)
93
+ overall_bar.refresh()
94
+ video_progress_bar.close()
95
+ video_progress_bar = None
96
  continue
97
 
98
+ # Process INFO messages (Level 2 sub-step)
99
  if "INFO:" in stripped_line:
100
+ parts = stripped_line.split("INFO:", 1)
101
+ msg = parts[1].strip() if len(parts) > 1 else ""
102
+ print(stripped_line) # Log the message
103
+
104
+ # For the first 4 INFO messages, simply count them.
105
+ if processed_steps < irrelevant_steps:
106
+ processed_steps += 1
107
+ continue
108
+ else:
109
+ # A new relevant INFO message has arrived.
110
+ # If a sub-bar exists (whether full or not), finish it now.
111
+ if sub_bar is not None:
112
+ if sub_ticks < sub_tick_total:
113
+ sub_bar.update(sub_tick_total - sub_ticks)
114
+ sub_bar.close()
115
+ overall_bar.update(1)
116
+ overall_bar.refresh()
117
+ sub_bar = None
118
+ sub_ticks = 0
119
+ # Start a new sub-step bar for the current INFO message.
120
+ sub_bar = tqdm(total=sub_tick_total, desc=msg, position=2,
121
+ ncols=120, dynamic_ncols=False, leave=True)
122
+ sub_ticks = 0
123
  continue
124
  else:
125
  print(stripped_line)
126
+ else:
127
+ # No new data within 40ms.
128
+ if sub_bar is not None:
129
+ if sub_ticks < sub_tick_total:
130
+ sub_bar.update(1)
131
+ sub_ticks += 1
132
+ sub_bar.refresh()
133
+ # If full (60 seconds reached), do not advance overall step—just remain waiting.
134
  if process.poll() is not None:
135
  break
136
 
137
+ # Drain any remaining output.
138
+ for line in process.stdout:
139
+ print(line.strip())
 
 
 
140
  process.wait()
141
+ if video_progress_bar is not None:
142
  video_progress_bar.close()
143
+ if sub_bar is not None:
144
+ sub_bar.close()
145
+ overall_bar.close()
146
+
147
  if process.returncode == 0:
148
+ print("Command executed successfully.")
149
  return "generated_video.mp4"
150
  else:
151
+ print("Error executing command.")
152
+ raise Exception("Error executing command")
153
 
154
+ # Gradio UI to trigger inference
155
  with gr.Blocks() as demo:
156
  with gr.Column():
157
+ gr.Markdown("# Wan 2.1 1.3B")
158
+ gr.Markdown("Enjoy this simple working UI, duplicate the space to skip the queue :)")
159
  prompt = gr.Textbox(label="Prompt")
160
+ submit_btn = gr.Button("Submit")
161
  video_res = gr.Video(label="Generated Video")
162
 
163
  submit_btn.click(
 
166
  outputs=[video_res]
167
  )
168
 
169
+ demo.queue().launch(show_error=True, show_api=False, ssr_mode=False)