akclick1401 commited on
Commit
eedb736
·
verified ·
1 Parent(s): 8d69bd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +392 -8
app.py CHANGED
@@ -1,10 +1,394 @@
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- with gr.Blocks(fill_height=True) as demo:
4
- with gr.Sidebar():
5
- gr.Markdown("# Inference Provider")
6
- gr.Markdown("This Space showcases the Wan-AI/Wan2.1-T2V-14B model, served by the replicate API. Sign in with your Hugging Face account to use this API.")
7
- button = gr.LoginButton("Sign in")
8
- gr.load("models/Wan-AI/Wan2.1-T2V-14B", accept_token=button, provider="replicate")
9
-
10
- demo.launch()
 
1
+ import copy
2
+ import os
3
+ import random
4
+
5
+ os.system('pip install dashscope')
6
  import gradio as gr
7
+ import dashscope
8
+ from dashscope import VideoSynthesis
9
+ from examples import t2v_examples, i2v_examples
10
+ import time
11
+
12
+ DASHSCOPE_API_KEY = os.getenv('DASHSCOPE_API_KEY')
13
+ dashscope.api_key = DASHSCOPE_API_KEY
14
+
15
+ KEEP_SUCCESS_TASK = 3600 * 10
16
+ KEEP_RUNING_TASK = 3600 * 1
17
+ # the total running task number in 1800 seconds
18
+ LIMIT_RUNING_TASK = 10
19
+ LIMIT_HISTORY_RUNING_TASK = 20
20
+ FRESH_TIME = None
21
+
22
+ task_status = {}
23
+
24
+ total_task_info = {
25
+ "total_process_cost": 0,
26
+ "total_complete_task": 0,
27
+ "total_submit": 0,
28
+ "latest_1h_submit_status": {
29
+
30
+ }
31
+ }
32
+
33
+ def get_submit_code():
34
+ submit_code = random.randint(0, 2147483647)
35
+ #
36
+ for sub_c, sub_info in copy.deepcopy(total_task_info["latest_1h_submit_status"]).items():
37
+ if time.time() - sub_info > 3600:
38
+ total_task_info["latest_1h_submit_status"].pop(sub_c)
39
+ return submit_code
40
+ def t2v_generation(prompt, resolution, watermark_wan, seed=-1):
41
+ seed = seed if seed >= 0 else random.randint(0, 2147483647)
42
+ total_task_info["latest_1h_submit_status"][get_submit_code()] = time.time()
43
+ total_task_info["total_submit"] += 1
44
+ if not allow_task_num():
45
+ gr.Info(f"Warning: The number of running tasks is too large, the estimate waiting time is {get_waiting_time('-1')} s.")
46
+ return None, gr.Button(visible=True)
47
+ try:
48
+ rsp = VideoSynthesis.call(model="wanx2.1-t2v-plus", prompt=prompt, seed=seed,
49
+ watermark_wanx=watermark_wan, size=resolution)
50
+ video_url = rsp.output.video_url
51
+ return video_url, gr.Button(visible=True)
52
+ except Exception as e:
53
+ gr.Warning(f"Warning: {e}")
54
+ return None, gr.Button(visible=True)
55
+
56
+
57
+ def t2v_generation_async(prompt, size, watermark_wan, seed=-1):
58
+ print(seed)
59
+ seed = seed if seed >= 0 else random.randint(0, 2147483647)
60
+ total_task_info["latest_1h_submit_status"][get_submit_code()] = time.time()
61
+ total_task_info["total_submit"] += 1
62
+ print(seed)
63
+ if not allow_task_num():
64
+ gr.Info(f"Warning: The number of running tasks is too large, the estimate waiting time is {get_waiting_time('-1')} s.")
65
+ return None, False, gr.Button(visible=True), gr.Button(visible=False), gr.Slider(), gr.Slider()
66
+ try:
67
+ rsp = VideoSynthesis.async_call(model="wanx2.1-t2v-plus",
68
+ prompt=prompt,
69
+ size=size,
70
+ seed=seed,
71
+ watermark_wanx=watermark_wan)
72
+ task_id = rsp.output.task_id
73
+ status = False
74
+ return task_id, status, gr.Button(visible=False), gr.Button(visible=True), get_cost_time(task_id), get_waiting_time(task_id)
75
+ except Exception as e:
76
+ gr.Warning(f"Warning: {e}")
77
+ return None, True, gr.Button(), gr.Button(), gr.Slider(), gr.Slider()
78
+
79
+
80
+ def i2v_generation(prompt, image, watermark_wan, seed=-1):
81
+ seed = seed if seed >= 0 else random.randint(0, 2147483647)
82
+ video_url = None
83
+ try:
84
+ rsp = VideoSynthesis.call(model="wanx2.1-i2v-plus", prompt=prompt, img_url=image,
85
+ seed=seed,
86
+ watermark_wanx=watermark_wan
87
+ )
88
+ video_url = rsp.output.video_url
89
+ except Exception as e:
90
+ gr.Warning(f"Warning: {e}")
91
+ return video_url
92
+
93
+
94
+ def i2v_generation_async(prompt, image, watermark_wan, seed=-1):
95
+ seed = seed if seed >= 0 else random.randint(0, 2147483647)
96
+ total_task_info["latest_1h_submit_status"][get_submit_code()] = time.time()
97
+ total_task_info["total_submit"] += 1
98
+ if not allow_task_num():
99
+ gr.Info(f"Warning: The number of running tasks is too large, the estimate waiting time is {get_waiting_time('-1')} s.")
100
+ return "", None, gr.Button(visible=True), gr.Button(visible=False), gr.Slider(), gr.Slider()
101
+ try:
102
+ rsp = VideoSynthesis.async_call(model="wanx2.1-i2v-plus", prompt=prompt, seed=seed,
103
+ img_url=image, watermark_wanx=watermark_wan)
104
+ print(rsp)
105
+ task_id = rsp.output.task_id
106
+ status = False
107
+ return task_id, status, gr.Button(visible=False), gr.Button(visible=True), get_cost_time(task_id), get_waiting_time(task_id)
108
+ except Exception as e:
109
+ gr.Warning(f"Warning: {e}")
110
+ return "", None, gr.Button(), gr.Button(), gr.Slider(), gr.Slider()
111
+
112
+ def get_result_with_task_id(task_id):
113
+ if task_id == "": return True, None
114
+ try:
115
+ rsp = VideoSynthesis.fetch(task=task_id)
116
+ print(rsp)
117
+ if rsp.output.task_status == "FAILED":
118
+ gr.Info(f"Warning: task running {rsp.output.task_status}")
119
+ status = True
120
+ video_url = None
121
+ else:
122
+ video_url = rsp.output.video_url
123
+ video_url = video_url if video_url != "" else None
124
+ status = video_url is not None
125
+ if status:
126
+ total_task_info["total_complete_task"] += 1
127
+ total_task_info["total_process_cost"] += time.time() - task_status[task_id]["time"]
128
+ print(total_task_info["total_complete_task"], total_task_info["total_process_cost"])
129
+ except:
130
+ video_url = None
131
+ status = False
132
+ return status, None if video_url == "" else video_url
133
+ # return True, "https://dashscope-result-wlcb.oss-cn-wulanchabu.aliyuncs.com/1d/f8/20250220/e7d3f375/ccc590a2-7e90-4d92-84bc-22668db42979.mp4?Expires=1740137152&OSSAccessKeyId=LTAI5tQZd8AEcZX6KZV4G8qL&Signature=i3S3jA5FY6XYfvzZNHnvQiPzZSw%3D"
134
+
135
+
136
+
137
+ def allow_task_num():
138
+ num = 0
139
+ total_num = 0
140
+ for task_id in task_status:
141
+ if not task_status[task_id]["status"] and task_status[task_id]["time"] + 1800 > time.time():
142
+ num += 1
143
+ if not task_status[task_id]["status"]:
144
+ total_num += 1
145
+ return num < LIMIT_RUNING_TASK or total_num < LIMIT_HISTORY_RUNING_TASK
146
+
147
+ def get_waiting_time(task_id):
148
+ # if the num of running task < Limit
149
+ # waiting time = num * 480s
150
+ # task_id not in task_status, return a large number
151
+ # prediction the waiting time
152
+ # avg_cost * latest submit time
153
+ num = 0
154
+ for task_id in task_status:
155
+ if not task_status[task_id]["status"]:
156
+ num += 1
157
+ latest_submit_tasks = len(total_task_info["latest_1h_submit_status"])
158
+ print("latest submit tasks", latest_submit_tasks)
159
+ if task_id in task_status:
160
+ return int(640 - (time.time() - task_status[task_id]["time"]))
161
+ else:
162
+ return int(latest_submit_tasks * (total_task_info["total_process_cost"]/(total_task_info["total_complete_task"]+1)))
163
+
164
+ def online_get_waiting_time(task, t2v_task_id, i2v_task_id):
165
+ task_id = t2v_task_id if task == "t2v" else i2v_task_id
166
+ return get_waiting_time(task_id)
167
+
168
+ def clean_task_status():
169
+ # clean the task over 1800 seconds
170
+ for task_id in copy.deepcopy(task_status):
171
+ if task_id == "": continue
172
+ # finished task, keep 3600 seconds
173
+ if task_status[task_id]["status"]:
174
+ if task_status[task_id]["time"] + KEEP_SUCCESS_TASK < time.time():
175
+ task_status.pop(task_id)
176
+ else:
177
+ # clean the task over 3600 * 2 seconds
178
+ if task_status[task_id]["time"] + KEEP_RUNING_TASK < time.time():
179
+ task_status.pop(task_id)
180
+
181
+
182
+ def get_cost_time(task_id):
183
+ if task_id in task_status and not task_status[task_id]["status"]:
184
+ et = int(time.time() - task_status[task_id]["time"])
185
+ return f"{et:.2f}"
186
+ else:
187
+ return gr.Textbox()
188
+
189
+ def online_get_cost_time(task, t2v_task_id, i2v_task_id):
190
+ task_id = t2v_task_id if task == "t2v" else i2v_task_id
191
+ return get_cost_time(task_id)
192
+
193
+
194
+ def get_process_bar(task, t2v_task_id, i2v_task_id, status):
195
+ task_id = t2v_task_id if task == "t2v" else i2v_task_id
196
+ clean_task_status()
197
+ if task_id not in task_status:
198
+ task_status[task_id] = {
199
+ "value": 0 if not task_id == "" else 100,
200
+ "status": status if not task_id == "" else True,
201
+ "time": time.time(),
202
+ "url": None
203
+ }
204
+ if not task_status[task_id]["status"]:
205
+ # only when > 50% do check status
206
+ if task_status[task_id]["value"] >= 5 and task_status[task_id]["value"] % 5 == 0:
207
+ status, video_url = get_result_with_task_id(task_id)
208
+ else:
209
+ status, video_url = False, None
210
+ task_status[task_id]["status"] = status
211
+ task_status[task_id]["url"] = video_url
212
+ if task_status[task_id]["status"]:
213
+ task_status[task_id]["value"] = 100
214
+ else:
215
+ task_status[task_id]["value"] += 5
216
+ if task_status[task_id]["value"] >= 100 and not task_status[task_id]["status"]:
217
+ task_status[task_id]["value"] = 95
218
+ # print(task_id, task_status[task_id], task_status)
219
+ value = task_status[task_id]["value"]
220
+ return gr.Slider(label=f"({value}%)Generating" if value % 2 == 1 else f"({value}%)Generating.....", value=value)
221
+
222
+
223
+ with gr.Blocks() as demo:
224
+ gr.HTML("""
225
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
226
+ Wan2.1: Open and Advanced Large-Scale Video Generative Models
227
+ </div>
228
+ <div style="text-align: center;">
229
+ <a href="https://github.com/Wan-Video/Wan2.1">Code</a> |
230
+ <a href="https://huggingface.co/Wan-AI">Huggingface</a> |
231
+ <a href="https://modelscope.cn/organization/Wan-AI">Modelscope</a>
232
+ </div>
233
+ <div style="text-align: center; font-size: 16px; font-weight: bold; margin-bottom: 20px;">
234
+ We are excited to announce that Wan's international experience page is officially live, supporting image and video generation, and it's completely free. We welcome you to try it out!
235
+ <a href="https://wan.video/wanxiang/creation">Wan Web</a>
236
+ </div>
237
+ """)
238
+ t2v_task_id = gr.State(value="")
239
+ i2v_task_id = gr.State(value="")
240
+ status = gr.State(value=False)
241
+ task = gr.State(value="t2v")
242
+ with gr.Row():
243
+ with gr.Column():
244
+ with gr.Row():
245
+ with gr.Tabs():
246
+ # Text to Video Tab
247
+ with gr.TabItem("Text to Video") as t2v_tab:
248
+ with gr.Row():
249
+ txt2vid_prompt = gr.Textbox(
250
+ label="Prompt",
251
+ placeholder="Describe the video you want to generate",
252
+ lines=19,
253
+ )
254
+ with gr.Row():
255
+ resolution = gr.Dropdown(
256
+ label="Resolution",
257
+ choices=["1280*720", "960*960", "720*1280", "1088*832", "832*1088"],
258
+ value="1280*720",
259
+ )
260
+ with gr.Row():
261
+ run_t2v_button = gr.Button("Generate Video")
262
+ t2v_refresh_status = gr.Button("Refresh Generating Status", visible=False)
263
+ # Image to Video Tab
264
+ with gr.TabItem("Image to Video") as i2v_tab:
265
+ with gr.Row():
266
+ with gr.Column():
267
+ img2vid_image = gr.Image(
268
+ type="filepath",
269
+ label="Upload Input Image",
270
+ elem_id="image_upload",
271
+ )
272
+ img2vid_prompt = gr.Textbox(
273
+ label="Prompt",
274
+ placeholder="Describe the video you want to generate",
275
+ value="",
276
+ lines=5,
277
+ )
278
+ with gr.Row():
279
+ run_i2v_button = gr.Button("Generate Video")
280
+ i2v_refresh_status = gr.Button("Refresh Generating Status", visible=False)
281
+ with gr.Column():
282
+ with gr.Row():
283
+ result_gallery = gr.Video(label='Generated Video',
284
+ interactive=False,
285
+ height=500)
286
+ with gr.Row():
287
+ watermark_wan = gr.Checkbox(label="Watermark", value=True, visible=True, container=False)
288
+ seed = gr.Number(label="Seed", value=-1, container=True)
289
+ # cost_time = gr.Number(label="Cost Time(secs)", value=online_get_cost_time, interactive=False,
290
+ # every=FRESH_TIME, inputs=[task, t2v_task_id, i2v_task_id], container=True)
291
+ cost_time = gr.Number(label="Cost Time(secs)", value=0, interactive=False, container=True)
292
+ # waiting_time = gr.Number(label="Estimated Waiting Time(secs)", value=online_get_waiting_time, interactive=False,
293
+ # every=FRESH_TIME, inputs=[task, t2v_task_id, i2v_task_id], container=True)
294
+ waiting_time = gr.Number(label="Estimated Waiting Time(secs)", value=0, interactive=False, container=True)
295
+ # process_bar = gr.Slider(show_label=True, label="", value=get_process_bar, maximum=100,
296
+ # interactive=True, every=FRESH_TIME, inputs=[task, t2v_task_id, i2v_task_id, status], container=True)
297
+ process_bar = gr.Slider(show_label=True, label="", value=100, maximum=100,
298
+ interactive=True, container=True)
299
+ with gr.Row():
300
+ gr.Markdown('<span style="color: blue;">Due to automatic refresh of task status causing significant network congestion, please manually click the "Refresh Generating Status" button to check the task status.</span>')
301
+ fake_video = gr.Video(label='Examples', visible=False, interactive=False)
302
+ with gr.Row(visible=True) as t2v_eg:
303
+ gr.Examples(t2v_examples,
304
+ inputs=[txt2vid_prompt, result_gallery],
305
+ outputs=[result_gallery])
306
+
307
+ with gr.Row(visible=False) as i2v_eg:
308
+ gr.Examples(i2v_examples,
309
+ inputs=[img2vid_prompt, img2vid_image, result_gallery],
310
+ outputs=[result_gallery])
311
+
312
+
313
+ def process_change(task_id, task):
314
+ status = task_status.get(task_id, {"status":False})["status"]
315
+ if status:
316
+ video_url = task_status[task_id]["url"]
317
+ ret_t2v_btn = gr.Button(visible=True) if task == 't2v' else gr.Button()
318
+ ret_t2v_status_btn = gr.Button(visible=False) if task == 't2v' else gr.Button()
319
+ ret_i2v_btn = gr.Button(visible=True) if task == 'i2v' else gr.Button()
320
+ ret_i2v_status_btn = gr.Button(visible=False) if task == 'i2v' else gr.Button()
321
+ return gr.Video(value=video_url), ret_t2v_btn, ret_i2v_btn, ret_t2v_status_btn, ret_i2v_status_btn
322
+ return gr.Video(value=None), gr.Button(), gr.Button(), gr.Button(), gr.Button()
323
+
324
+ def online_process_change(task, t2v_task_id, i2v_task_id):
325
+ task_id = t2v_task_id if task == 't2v' else i2v_task_id
326
+ return process_change(task_id, task)
327
+
328
+ process_bar.change(online_process_change, inputs=[task, t2v_task_id, i2v_task_id],
329
+ outputs=[result_gallery, run_t2v_button, run_i2v_button,
330
+ t2v_refresh_status, i2v_refresh_status])
331
+
332
+
333
+ def switch_i2v_tab():
334
+ return gr.Row(visible=False), gr.Row(visible=True), "i2v"
335
+
336
+
337
+ def switch_t2v_tab():
338
+ return gr.Row(visible=True), gr.Row(visible=False), "t2v"
339
+
340
+
341
+ i2v_tab.select(switch_i2v_tab, outputs=[t2v_eg, i2v_eg, task])
342
+ t2v_tab.select(switch_t2v_tab, outputs=[t2v_eg, i2v_eg, task])
343
+
344
+ run_t2v_button.click(
345
+ fn=t2v_generation_async,
346
+ inputs=[txt2vid_prompt, resolution, watermark_wan, seed],
347
+ outputs=[t2v_task_id, status, run_t2v_button, t2v_refresh_status, cost_time, waiting_time],
348
+ )
349
+
350
+ run_i2v_button.click(
351
+ fn=i2v_generation_async,
352
+ inputs=[img2vid_prompt, img2vid_image, watermark_wan, seed],
353
+ outputs=[i2v_task_id, status, run_i2v_button, i2v_refresh_status, cost_time, waiting_time],
354
+ )
355
+
356
+ def status_refresh(task_id, task, status):
357
+ if task_id in task_status and not task_status[task_id]["status"]:
358
+ cost_time = int(time.time() - task_status[task_id]["time"])
359
+ else:
360
+ cost_time = 0
361
+ status, video_url = get_result_with_task_id(task_id)
362
+ if task_id not in task_status:
363
+ task_status[task_id] = {"status": status, "url": video_url, "time": time.time(), "value": 100 if status else 0}
364
+ else:
365
+ task_status[task_id]["status"] = status
366
+ task_status[task_id]["url"] = video_url
367
+ waiting_time = get_waiting_time(task_id)
368
+ value = task_status.get(task_id, {"value": 100})["value"]
369
+ value = max(value, int(cost_time*100/waiting_time))
370
+ task_status[task_id]["value"] = value if value < 100 else 100
371
+ if not video_url == "" and status: value = 100
372
+ process_bar = gr.Slider(label=f"({value}%)Generating" if value % 2 == 1 else f"({value}%)Generating.....", value=value)
373
+ process_change_ret = process_change(task_id, task)
374
+ return *process_change_ret, cost_time, waiting_time, process_bar
375
+
376
+
377
+ t2v_refresh_status.click(
378
+ fn=status_refresh,
379
+ inputs=[t2v_task_id, task, status],
380
+ outputs=[result_gallery, run_t2v_button, run_i2v_button,
381
+ t2v_refresh_status, i2v_refresh_status,
382
+ cost_time, waiting_time, process_bar]
383
+ )
384
+
385
+ i2v_refresh_status.click(
386
+ fn=status_refresh,
387
+ inputs=[i2v_task_id, task, status],
388
+ outputs=[result_gallery, run_t2v_button, run_i2v_button,
389
+ t2v_refresh_status, i2v_refresh_status,
390
+ cost_time, waiting_time, process_bar]
391
+ )
392
 
393
+ #demo.queue(max_size=10)
394
+ demo.launch(ssr_mode=False)