Intercept-Intelligence committed on
Commit
5d40c04
·
verified ·
1 Parent(s): 0281435

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +388 -63
app.py CHANGED
@@ -1,69 +1,394 @@
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs:
6
- https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
7
- """
8
- client = InferenceClient("damo-vilab/modelscope-text-to-video-synthesis")
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- # NOTE: Video models don't usually use "streaming" generation, so we'll just call once
29
- payload = {
30
- "inputs": message,
31
- "parameters": {
32
- "max_new_tokens": max_tokens,
33
- "temperature": temperature,
34
- "top_p": top_p,
35
- }
36
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Post directly to the model
39
- response = client.post(json=payload)
 
 
 
 
 
 
 
 
40
 
41
- video_url = response.get("video", None)
42
 
43
- if video_url:
44
- yield video_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  else:
46
- yield "Failed to generate video."
47
-
48
- """
49
- For information on how to customize the ChatInterface, peruse the gradio docs:
50
- https://www.gradio.app/docs/chatinterface
51
- """
52
- demo = gr.ChatInterface(
53
- respond,
54
- additional_inputs=[
55
- gr.Textbox(value="You are generating a creative video.", label="System message"),
56
- gr.Slider(minimum=1, maximum=1000, value=250, step=1, label="Max new tokens"),
57
- gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
58
- gr.Slider(
59
- minimum=0.1,
60
- maximum=1.0,
61
- value=0.9,
62
- step=0.05,
63
- label="Top-p (nucleus sampling)",
64
- ),
65
- ],
66
- )
67
-
68
- if __name__ == "__main__":
69
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ import random
4
+
5
+ os.system('pip install dashscope')
6
  import gradio as gr
7
+ import dashscope
8
+ from dashscope import VideoSynthesis
9
+ from examples import t2v_examples, i2v_examples
10
+ import time
11
+
12
+ DASHSCOPE_API_KEY = os.getenv('DASHSCOPE_API_KEY')
13
+ dashscope.api_key = DASHSCOPE_API_KEY
14
+
15
# Retention windows, in seconds, applied when pruning ``task_status``.
KEEP_SUCCESS_TASK = 10 * 3600  # entries whose "status" is true
KEEP_RUNING_TASK = 1 * 3600  # entries whose "status" is still false

# Admission limits consulted by ``allow_task_num``.
LIMIT_RUNING_TASK = 10  # unfinished tasks recorded within the last 1800 seconds
LIMIT_HISTORY_RUNING_TASK = 20  # unfinished tasks of any age

# Only referenced by the commented-out ``every=`` component arguments.
FRESH_TIME = None

# task_id -> {"value": progress percent, "status": finished flag,
#             "time": time.time() when the entry was created, "url": video url}
task_status = dict()

# Process-wide counters.  "latest_1h_submit_status" maps a random submit code
# to the time.time() of the submission it stands for.
total_task_info = {
    "total_process_cost": 0,
    "total_complete_task": 0,
    "total_submit": 0,
    "latest_1h_submit_status": dict(),
}
32
+
33
def get_submit_code():
    """Return a fresh random submit code and prune stale submission records.

    ``total_task_info["latest_1h_submit_status"]`` maps a submit code to the
    ``time.time()`` value stored for it.  Every entry older than 3600 seconds
    is removed.  The new code is NOT stored here; the caller uses it as the
    key for its own timestamp.

    Returns:
        int: a random integer in ``[0, 2147483647]``.
    """
    submit_code = random.randint(0, 2147483647)
    recent_submits = total_task_info["latest_1h_submit_status"]
    # Snapshot the keys so entries can be removed while looping.  The values
    # are plain timestamps, so deep-copying the mapping is unnecessary.
    for code in list(recent_submits):
        if time.time() - recent_submits[code] > 3600:
            recent_submits.pop(code)
    return submit_code
40
def t2v_generation(prompt, resolution, watermark_wan, seed=-1):
    """Generate a video from text with a blocking ``VideoSynthesis.call``.

    Args:
        prompt: text description of the video.
        resolution: passed to the service as ``size``.
        watermark_wan: passed to the service as ``watermark_wanx``.
        seed: generation seed; ``None`` or a negative value selects a random
            seed, any other value is converted to ``int``.

    Returns:
        tuple: ``(video_url, button)``.  ``video_url`` is ``None`` when the
        request is rejected by ``allow_task_num`` or the call raises; the
        button is always a visible ``gr.Button``.
    """
    # ``None >= 0`` raises TypeError, so handle a missing seed explicitly and
    # hand the service an integer rather than whatever numeric type arrived.
    if seed is None or seed < 0:
        seed = random.randint(0, 2147483647)
    else:
        seed = int(seed)
    # The submission is recorded before the admission check, so rejected
    # requests are counted as well.
    total_task_info["latest_1h_submit_status"][get_submit_code()] = time.time()
    total_task_info["total_submit"] += 1
    if not allow_task_num():
        gr.Info(f"Warning: The number of running tasks is too large, the estimate waiting time is {get_waiting_time('-1')} s.")
        return None, gr.Button(visible=True)
    try:
        rsp = VideoSynthesis.call(model="wanx2.1-t2v-plus", prompt=prompt, seed=seed,
                                  watermark_wanx=watermark_wan, size=resolution)
        video_url = rsp.output.video_url
        return video_url, gr.Button(visible=True)
    except Exception as e:
        # Best effort: surface the failure in the UI instead of crashing.
        gr.Warning(f"Warning: {e}")
        return None, gr.Button(visible=True)
55
+
56
+
57
def t2v_generation_async(prompt, size, watermark_wan, seed=-1):
    """Submit a text-to-video task with ``VideoSynthesis.async_call``.

    Args:
        prompt: text description of the video.
        size: resolution string forwarded to the service.
        watermark_wan: passed to the service as ``watermark_wanx``.
        seed: generation seed; ``None`` (e.g. a cleared number field) or a
            negative value selects a random seed, any other value is
            converted to ``int``.

    Returns:
        tuple: six values matching the click handler outputs, in order:
        task id, status flag, "generate" button, "refresh" button, cost-time
        update and waiting-time update.  On success the generate button is
        hidden and the refresh button shown.  When the request is rejected
        by ``allow_task_num`` the task id is ``None`` and the status flag is
        ``False``; when the call raises, the task id is ``None`` and the
        status flag is ``True``.
    """
    print(seed)
    # ``None >= 0`` raises TypeError, so handle a missing seed explicitly and
    # hand the service an integer rather than a float from the number field.
    if seed is None or seed < 0:
        seed = random.randint(0, 2147483647)
    else:
        seed = int(seed)
    # The submission is recorded before the admission check, so rejected
    # requests are counted as well.
    total_task_info["latest_1h_submit_status"][get_submit_code()] = time.time()
    total_task_info["total_submit"] += 1
    print(seed)
    if not allow_task_num():
        gr.Info(f"Warning: The number of running tasks is too large, the estimate waiting time is {get_waiting_time('-1')} s.")
        return None, False, gr.Button(visible=True), gr.Button(visible=False), gr.Slider(), gr.Slider()
    try:
        rsp = VideoSynthesis.async_call(model="wanx2.1-t2v-plus",
                                        prompt=prompt,
                                        size=size,
                                        seed=seed,
                                        watermark_wanx=watermark_wan)
        task_id = rsp.output.task_id
        status = False
        return task_id, status, gr.Button(visible=False), gr.Button(visible=True), get_cost_time(task_id), get_waiting_time(task_id)
    except Exception as e:
        # Best effort: surface the failure in the UI instead of crashing.
        gr.Warning(f"Warning: {e}")
        return None, True, gr.Button(), gr.Button(), gr.Slider(), gr.Slider()
78
+
79
+
80
def i2v_generation(prompt, image, watermark_wan, seed=-1):
    """Generate a video from an image with a blocking ``VideoSynthesis.call``.

    Unlike the other generation entry points, this one neither records the
    submission nor consults ``allow_task_num``.

    Args:
        prompt: text description of the video.
        image: passed to the service as ``img_url``.
        watermark_wan: passed to the service as ``watermark_wanx``.
        seed: generation seed; ``None`` or a negative value selects a random
            seed, any other value is converted to ``int``.

    Returns:
        The generated video URL, or ``None`` when the call raises.
    """
    # ``None >= 0`` raises TypeError, so handle a missing seed explicitly and
    # hand the service an integer rather than whatever numeric type arrived.
    if seed is None or seed < 0:
        seed = random.randint(0, 2147483647)
    else:
        seed = int(seed)
    video_url = None
    try:
        rsp = VideoSynthesis.call(model="wanx2.1-i2v-plus", prompt=prompt, img_url=image,
                                  seed=seed,
                                  watermark_wanx=watermark_wan
                                  )
        video_url = rsp.output.video_url
    except Exception as e:
        # Best effort: surface the failure in the UI instead of crashing.
        gr.Warning(f"Warning: {e}")
    return video_url
92
+
93
+
94
def i2v_generation_async(prompt, image, watermark_wan, seed=-1):
    """Submit an image-to-video task with ``VideoSynthesis.async_call``.

    Args:
        prompt: text description of the video.
        image: passed to the service as ``img_url``.
        watermark_wan: passed to the service as ``watermark_wanx``.
        seed: generation seed; ``None`` (e.g. a cleared number field) or a
            negative value selects a random seed, any other value is
            converted to ``int``.

    Returns:
        tuple: six values matching the click handler outputs, in order:
        task id, status flag, "generate" button, "refresh" button, cost-time
        update and waiting-time update.  On success the generate button is
        hidden and the refresh button shown.  When the request is rejected
        by ``allow_task_num`` or the call raises, the task id is ``""`` and
        the status flag is ``None``.
    """
    # ``None >= 0`` raises TypeError, so handle a missing seed explicitly and
    # hand the service an integer rather than a float from the number field.
    if seed is None or seed < 0:
        seed = random.randint(0, 2147483647)
    else:
        seed = int(seed)
    # The submission is recorded before the admission check, so rejected
    # requests are counted as well.
    total_task_info["latest_1h_submit_status"][get_submit_code()] = time.time()
    total_task_info["total_submit"] += 1
    if not allow_task_num():
        gr.Info(f"Warning: The number of running tasks is too large, the estimate waiting time is {get_waiting_time('-1')} s.")
        return "", None, gr.Button(visible=True), gr.Button(visible=False), gr.Slider(), gr.Slider()
    try:
        rsp = VideoSynthesis.async_call(model="wanx2.1-i2v-plus", prompt=prompt, seed=seed,
                                        img_url=image, watermark_wanx=watermark_wan)
        print(rsp)
        task_id = rsp.output.task_id
        status = False
        return task_id, status, gr.Button(visible=False), gr.Button(visible=True), get_cost_time(task_id), get_waiting_time(task_id)
    except Exception as e:
        # Best effort: surface the failure in the UI instead of crashing.
        gr.Warning(f"Warning: {e}")
        return "", None, gr.Button(), gr.Button(), gr.Slider(), gr.Slider()
111
+
112
def get_result_with_task_id(task_id):
    """Fetch the current state of a video task via ``VideoSynthesis.fetch``.

    Args:
        task_id: identifier of the task; the empty string means "no task".

    Returns:
        tuple: ``(status, video_url)``.  ``status`` is ``True`` when there is
        nothing left to wait for: no task, the task reports ``FAILED``, or a
        video URL is available.  ``video_url`` is the URL when one is
        available and ``None`` otherwise.  If fetching raises, the result is
        ``(False, None)`` so the caller can retry later.
    """
    if task_id == "":
        return True, None
    try:
        rsp = VideoSynthesis.fetch(task=task_id)
        print(rsp)
        if rsp.output.task_status == "FAILED":
            gr.Info(f"Warning: task running {rsp.output.task_status}")
            return True, None
        video_url = rsp.output.video_url
        video_url = video_url if video_url != "" else None
        status = video_url is not None
    except Exception as e:
        # Best effort: report "not finished yet", but leave a trace of why.
        print(f"fetch task {task_id} failed: {e}")
        return False, None
    if status:
        # Statistics need the recorded start time.  A task that has no entry
        # yet has no timing data, and a task already marked finished has been
        # counted before, so both are left out instead of raising KeyError or
        # being counted twice.
        record = task_status.get(task_id)
        if record is not None and not record["status"]:
            total_task_info["total_complete_task"] += 1
            total_task_info["total_process_cost"] += time.time() - record["time"]
            print(total_task_info["total_complete_task"], total_task_info["total_process_cost"])
    return status, video_url
134
+
135
+
136
+
137
def allow_task_num():
    """Tell whether another generation task may be started.

    Counts the unfinished entries of ``task_status`` twice: those recorded
    within the last 1800 seconds, and all of them regardless of age.

    Returns:
        bool: ``True`` when the recent count is below ``LIMIT_RUNING_TASK``
        or the overall count is below ``LIMIT_HISTORY_RUNING_TASK``.
    """
    unfinished = [info for info in task_status.values() if not info["status"]]
    recent_unfinished = sum(
        1 for info in unfinished if info["time"] + 1800 > time.time()
    )
    if recent_unfinished < LIMIT_RUNING_TASK:
        return True
    return len(unfinished) < LIMIT_HISTORY_RUNING_TASK
146
+
147
def get_waiting_time(task_id):
    """Estimate, in seconds, how much longer a task has to wait.

    Args:
        task_id: identifier of the task; any id that is not a key of
            ``task_status`` (callers pass ``'-1'``) requests the estimate for
            a task that has not been recorded yet.

    Returns:
        int: for a recorded task, 640 seconds minus the time elapsed since
        its entry was created (this can be zero or negative for old tasks).
        Otherwise, the number of submissions recorded in the last hour
        multiplied by ``total_process_cost / (total_complete_task + 1)``.
    """
    latest_submit_tasks = len(total_task_info["latest_1h_submit_status"])
    print("latest submit tasks", latest_submit_tasks)
    # Look up the id that was actually passed in.  The previous loop over
    # ``task_status`` reused the name ``task_id`` and thereby replaced the
    # argument with the last key of the table.
    if task_id in task_status:
        return int(640 - (time.time() - task_status[task_id]["time"]))
    # "+ 1" keeps the division defined before any task has completed.
    average_cost = total_task_info["total_process_cost"] / (total_task_info["total_complete_task"] + 1)
    return int(latest_submit_tasks * average_cost)
163
+
164
def online_get_waiting_time(task, t2v_task_id, i2v_task_id):
    """Return ``get_waiting_time`` for the task id of the active mode.

    ``t2v_task_id`` is used when ``task`` equals ``"t2v"``; any other value
    selects ``i2v_task_id``.
    """
    if task == "t2v":
        return get_waiting_time(t2v_task_id)
    return get_waiting_time(i2v_task_id)
167
+
168
def clean_task_status():
    """Drop expired entries from ``task_status``.

    An entry expires once its ``"time"`` is older than ``KEEP_SUCCESS_TASK``
    seconds when its ``"status"`` is true, or older than ``KEEP_RUNING_TASK``
    seconds otherwise.  The entry stored under the empty task id is never
    removed.
    """
    # Snapshot the keys so entries can be removed while looping; copying the
    # records themselves is unnecessary.
    for task_id in list(task_status):
        if task_id == "":
            continue
        record = task_status[task_id]
        keep_seconds = KEEP_SUCCESS_TASK if record["status"] else KEEP_RUNING_TASK
        if record["time"] + keep_seconds < time.time():
            task_status.pop(task_id)
180
+
181
 
182
def get_cost_time(task_id):
    """Return the seconds elapsed for an unfinished, recorded task.

    Returns:
        The whole number of seconds since the task's entry was created,
        formatted with two decimals (for example ``"12.00"``).  For an
        unknown or finished task an empty ``gr.Textbox()`` is returned
        instead.
    """
    if task_id not in task_status or task_status[task_id]["status"]:
        return gr.Textbox()
    elapsed = int(time.time() - task_status[task_id]["time"])
    return f"{elapsed:.2f}"
188
+
189
def online_get_cost_time(task, t2v_task_id, i2v_task_id):
    """Return ``get_cost_time`` for the task id of the active mode.

    ``t2v_task_id`` is used when ``task`` equals ``"t2v"``; any other value
    selects ``i2v_task_id``.
    """
    if task == "t2v":
        return get_cost_time(t2v_task_id)
    return get_cost_time(i2v_task_id)
192
 
 
193
 
194
def get_process_bar(task, t2v_task_id, i2v_task_id, status):
    """Advance the progress entry of the active task and render it as a slider.

    Args:
        task: ``"t2v"`` selects ``t2v_task_id``; anything else selects
            ``i2v_task_id``.
        t2v_task_id: task id of the text-to-video mode.
        i2v_task_id: task id of the image-to-video mode.
        status: initial finished flag used when the task has no entry yet.

    Returns:
        A ``gr.Slider`` whose value and label show the progress percentage.
    """
    task_id = i2v_task_id if task != "t2v" else t2v_task_id
    clean_task_status()
    if task_id not in task_status:
        # The empty id means "no task" and is stored as already complete.
        has_task = not task_id == ""
        task_status[task_id] = {
            "value": 0 if has_task else 100,
            "status": status if has_task else True,
            "time": time.time(),
            "url": None,
        }
    entry = task_status[task_id]
    if not entry["status"]:
        # The service is queried whenever the progress value is a positive
        # multiple of 5.
        if entry["value"] >= 5 and entry["value"] % 5 == 0:
            status, video_url = get_result_with_task_id(task_id)
        else:
            status, video_url = False, None
        entry["status"] = status
        entry["url"] = video_url
    if entry["status"]:
        entry["value"] = 100
    else:
        entry["value"] += 5
    # An unfinished task never shows as complete.
    if entry["value"] >= 100 and not entry["status"]:
        entry["value"] = 95
    value = entry["value"]
    if value % 2 == 1:
        label = f"({value}%)Generating"
    else:
        label = f"({value}%)Generating....."
    return gr.Slider(label=label, value=value)
221
+
222
+
223
+ with gr.Blocks() as demo:
224
+ gr.HTML("""
225
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
226
+ Wan2.1: Open and Advanced Large-Scale Video Generative Models
227
+ </div>
228
+ <div style="text-align: center;">
229
+ <a href="https://github.com/Wan-Video/Wan2.1">Code</a> |
230
+ <a href="https://huggingface.co/Wan-AI">Huggingface</a> |
231
+ <a href="https://modelscope.cn/organization/Wan-AI">Modelscope</a>
232
+ </div>
233
+ <div style="text-align: center; font-size: 16px; font-weight: bold; margin-bottom: 20px;">
234
+ We are excited to announce that Wan's international experience page is officially live, supporting image and video generation, and it's completely free. We welcome you to try it out!
235
+ <a href="https://wan.video/wanxiang/creation">Wan Web</a>
236
+ </div>
237
+ """)
238
+ t2v_task_id = gr.State(value="")
239
+ i2v_task_id = gr.State(value="")
240
+ status = gr.State(value=False)
241
+ task = gr.State(value="t2v")
242
+ with gr.Row():
243
+ with gr.Column():
244
+ with gr.Row():
245
+ with gr.Tabs():
246
+ # Text to Video Tab
247
+ with gr.TabItem("Text to Video") as t2v_tab:
248
+ with gr.Row():
249
+ txt2vid_prompt = gr.Textbox(
250
+ label="Prompt",
251
+ placeholder="Describe the video you want to generate",
252
+ lines=19,
253
+ )
254
+ with gr.Row():
255
+ resolution = gr.Dropdown(
256
+ label="Resolution",
257
+ choices=["1280*720", "960*960", "720*1280", "1088*832", "832*1088"],
258
+ value="1280*720",
259
+ )
260
+ with gr.Row():
261
+ run_t2v_button = gr.Button("Generate Video")
262
+ t2v_refresh_status = gr.Button("Refresh Generating Status", visible=False)
263
+ # Image to Video Tab
264
+ with gr.TabItem("Image to Video") as i2v_tab:
265
+ with gr.Row():
266
+ with gr.Column():
267
+ img2vid_image = gr.Image(
268
+ type="filepath",
269
+ label="Upload Input Image",
270
+ elem_id="image_upload",
271
+ )
272
+ img2vid_prompt = gr.Textbox(
273
+ label="Prompt",
274
+ placeholder="Describe the video you want to generate",
275
+ value="",
276
+ lines=5,
277
+ )
278
+ with gr.Row():
279
+ run_i2v_button = gr.Button("Generate Video")
280
+ i2v_refresh_status = gr.Button("Refresh Generating Status", visible=False)
281
+ with gr.Column():
282
+ with gr.Row():
283
+ result_gallery = gr.Video(label='Generated Video',
284
+ interactive=False,
285
+ height=500)
286
+ with gr.Row():
287
+ watermark_wan = gr.Checkbox(label="Watermark", value=True, visible=True, container=False)
288
+ seed = gr.Number(label="Seed", value=-1, container=True)
289
+ # cost_time = gr.Number(label="Cost Time(secs)", value=online_get_cost_time, interactive=False,
290
+ # every=FRESH_TIME, inputs=[task, t2v_task_id, i2v_task_id], container=True)
291
+ cost_time = gr.Number(label="Cost Time(secs)", value=0, interactive=False, container=True)
292
+ # waiting_time = gr.Number(label="Estimated Waiting Time(secs)", value=online_get_waiting_time, interactive=False,
293
+ # every=FRESH_TIME, inputs=[task, t2v_task_id, i2v_task_id], container=True)
294
+ waiting_time = gr.Number(label="Estimated Waiting Time(secs)", value=0, interactive=False, container=True)
295
+ # process_bar = gr.Slider(show_label=True, label="", value=get_process_bar, maximum=100,
296
+ # interactive=True, every=FRESH_TIME, inputs=[task, t2v_task_id, i2v_task_id, status], container=True)
297
+ process_bar = gr.Slider(show_label=True, label="", value=100, maximum=100,
298
+ interactive=True, container=True)
299
+ with gr.Row():
300
+ gr.Markdown('<span style="color: blue;">Due to automatic refresh of task status causing significant network congestion, please manually click the "Refresh Generating Status" button to check the task status.</span>')
301
+ fake_video = gr.Video(label='Examples', visible=False, interactive=False)
302
+ with gr.Row(visible=True) as t2v_eg:
303
+ gr.Examples(t2v_examples,
304
+ inputs=[txt2vid_prompt, result_gallery],
305
+ outputs=[result_gallery])
306
+
307
+ with gr.Row(visible=False) as i2v_eg:
308
+ gr.Examples(i2v_examples,
309
+ inputs=[img2vid_prompt, img2vid_image, result_gallery],
310
+ outputs=[result_gallery])
311
+
312
+
313
def process_change(task_id, task):
    """Build the video and button updates for the current task state.

    Returns:
        tuple: ``(video, t2v_run, i2v_run, t2v_refresh, i2v_refresh)``.
        When the task is finished, the video carries the stored URL and, for
        the mode named by ``task``, the run button is shown and the refresh
        button hidden.  Otherwise the video is cleared and all four buttons
        are returned without changes.
    """
    finished = task_status.get(task_id, {"status":False})["status"]
    if not finished:
        return gr.Video(value=None), gr.Button(), gr.Button(), gr.Button(), gr.Button()
    video_url = task_status[task_id]["url"]
    is_t2v = task == 't2v'
    is_i2v = task == 'i2v'
    t2v_run = gr.Button(visible=True) if is_t2v else gr.Button()
    t2v_refresh = gr.Button(visible=False) if is_t2v else gr.Button()
    i2v_run = gr.Button(visible=True) if is_i2v else gr.Button()
    i2v_refresh = gr.Button(visible=False) if is_i2v else gr.Button()
    return gr.Video(value=video_url), t2v_run, i2v_run, t2v_refresh, i2v_refresh
323
+
324
def online_process_change(task, t2v_task_id, i2v_task_id):
    """Run ``process_change`` for the task id of the active mode.

    ``t2v_task_id`` is used when ``task`` equals ``'t2v'``; any other value
    selects ``i2v_task_id``.
    """
    if task == 't2v':
        return process_change(t2v_task_id, task)
    return process_change(i2v_task_id, task)
327
+
328
+ process_bar.change(online_process_change, inputs=[task, t2v_task_id, i2v_task_id],
329
+ outputs=[result_gallery, run_t2v_button, run_i2v_button,
330
+ t2v_refresh_status, i2v_refresh_status])
331
+
332
+
333
def switch_i2v_tab():
    """Hide the first examples row, show the second, and select ``"i2v"``."""
    first_row = gr.Row(visible=False)
    second_row = gr.Row(visible=True)
    return first_row, second_row, "i2v"
335
+
336
+
337
def switch_t2v_tab():
    """Show the first examples row, hide the second, and select ``"t2v"``."""
    first_row = gr.Row(visible=True)
    second_row = gr.Row(visible=False)
    return first_row, second_row, "t2v"
339
+
340
+
341
+ i2v_tab.select(switch_i2v_tab, outputs=[t2v_eg, i2v_eg, task])
342
+ t2v_tab.select(switch_t2v_tab, outputs=[t2v_eg, i2v_eg, task])
343
+
344
+ run_t2v_button.click(
345
+ fn=t2v_generation_async,
346
+ inputs=[txt2vid_prompt, resolution, watermark_wan, seed],
347
+ outputs=[t2v_task_id, status, run_t2v_button, t2v_refresh_status, cost_time, waiting_time],
348
+ )
349
+
350
+ run_i2v_button.click(
351
+ fn=i2v_generation_async,
352
+ inputs=[img2vid_prompt, img2vid_image, watermark_wan, seed],
353
+ outputs=[i2v_task_id, status, run_i2v_button, i2v_refresh_status, cost_time, waiting_time],
354
+ )
355
+
356
def status_refresh(task_id, task, status):
    """Poll a task once and build every UI update for the refresh buttons.

    Args:
        task_id: identifier of the task to poll.
        task: ``'t2v'`` or ``'i2v'``; forwarded to ``process_change``.
        status: accepted for the event signature only; it is replaced by the
            freshly fetched status.

    Returns:
        tuple: the five values of ``process_change`` followed by the elapsed
        seconds, the estimated waiting seconds and a ``gr.Slider`` carrying
        the progress percentage.
    """
    record = task_status.get(task_id)
    if record is not None and not record["status"]:
        cost_time = int(time.time() - record["time"])
    else:
        cost_time = 0
    status, video_url = get_result_with_task_id(task_id)
    if task_id not in task_status:
        task_status[task_id] = {"status": status, "url": video_url, "time": time.time(), "value": 100 if status else 0}
    else:
        task_status[task_id]["status"] = status
        task_status[task_id]["url"] = video_url
    waiting_time = get_waiting_time(task_id)
    value = task_status[task_id]["value"]
    # The estimate can reach zero (or go negative) for a long-running task;
    # only derive a percentage from it while it is positive, which avoids a
    # ZeroDivisionError.
    if waiting_time > 0:
        value = max(value, int(cost_time * 100 / waiting_time))
    # The slider's maximum is 100, so never report more than that.
    value = min(value, 100)
    if not video_url == "" and status:
        value = 100
    task_status[task_id]["value"] = value
    process_bar = gr.Slider(label=f"({value}%)Generating" if value % 2 == 1 else f"({value}%)Generating.....", value=value)
    process_change_ret = process_change(task_id, task)
    return (*process_change_ret, cost_time, waiting_time, process_bar)
375
+
376
+
377
+ t2v_refresh_status.click(
378
+ fn=status_refresh,
379
+ inputs=[t2v_task_id, task, status],
380
+ outputs=[result_gallery, run_t2v_button, run_i2v_button,
381
+ t2v_refresh_status, i2v_refresh_status,
382
+ cost_time, waiting_time, process_bar]
383
+ )
384
+
385
+ i2v_refresh_status.click(
386
+ fn=status_refresh,
387
+ inputs=[i2v_task_id, task, status],
388
+ outputs=[result_gallery, run_t2v_button, run_i2v_button,
389
+ t2v_refresh_status, i2v_refresh_status,
390
+ cost_time, waiting_time, process_bar]
391
+ )
392
+
393
+ #demo.queue(max_size=10)
394
+ demo.launch(ssr_mode=False)