lisonallen committed
Commit d4dcfc5 · Parent(s): ffb7037

Strengthen GPU error handling and add a CPU fallback mode to fix the ZeroGPU worker error

Files changed (1): app.py (+566 -315)
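The core pattern this commit introduces: probe the GPU once at startup — not just `torch.cuda.is_available()`, but an actual tiny allocation, since a ZeroGPU worker can report CUDA as available and still fail at first use — then route every later stage through a CPU fallback flag. A minimal, self-contained sketch of that pattern (the helper name `probe_gpu` is ours, not from app.py):

```python
import torch

def probe_gpu() -> bool:
    """Return True only if CUDA is reported available AND a tiny test op succeeds."""
    try:
        if not torch.cuda.is_available():
            return False
        t = torch.zeros(1, device='cuda') + 1  # forces a real allocation and kernel launch
        del t
        return True
    except Exception as e:  # broken drivers/workers surface here, not in is_available()
        print(f"GPU probe failed: {e}")
        return False

GPU_AVAILABLE = probe_gpu()
cpu_fallback_mode = not GPU_AVAILABLE
device = 'cuda' if GPU_AVAILABLE else 'cpu'
dtype = torch.float16 if GPU_AVAILABLE else torch.float32  # fp16 is impractically slow on CPU
```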
app.py CHANGED
@@ -105,13 +105,39 @@ import math
 # Check whether we are running in a Hugging Face Space
 IN_HF_SPACE = os.environ.get('SPACE_ID') is not None
 
+# Variables that track GPU availability
+GPU_AVAILABLE = False
+GPU_INITIALIZED = False
+last_update_time = time.time()
+
 # If running in a Hugging Face Space, import the spaces module
 if IN_HF_SPACE:
     try:
         import spaces
         print("Running in a Hugging Face Space; spaces module imported")
+
+        # Check GPU availability
+        try:
+            GPU_AVAILABLE = torch.cuda.is_available()
+            print(f"GPU available: {GPU_AVAILABLE}")
+            if GPU_AVAILABLE:
+                print(f"GPU device name: {torch.cuda.get_device_name(0)}")
+                print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9} GB")
+
+                # Run a tiny GPU operation to confirm the GPU is actually usable
+                test_tensor = torch.zeros(1, device='cuda')
+                test_tensor = test_tensor + 1
+                del test_tensor
+                print("GPU test operation succeeded")
+            else:
+                print("Warning: CUDA is not available; no GPU device detected")
+        except Exception as e:
+            GPU_AVAILABLE = False
+            print(f"Error while checking the GPU: {e}")
+            print("Falling back to CPU mode")
     except ImportError:
         print("Could not import the spaces module; probably not running in a Hugging Face Space")
+        GPU_AVAILABLE = torch.cuda.is_available()
 
 from PIL import Image
 from diffusers import AutoencoderKLHunyuanVideo
@@ -149,95 +175,194 @@ if not IN_HF_SPACE:
 else:
     # Use defaults in the Spaces environment
     print("Using default memory settings in the Spaces environment")
-    free_mem_gb = 60.0  # default to a high value in Spaces
-    high_vram = True
-    print(f'High-VRAM Mode: {high_vram}')
+    try:
+        if GPU_AVAILABLE:
+            free_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 * 0.9  # use 90% of GPU memory
+            high_vram = free_mem_gb > 10  # a more conservative threshold
+        else:
+            free_mem_gb = 6.0  # default value
+            high_vram = False
+    except Exception as e:
+        print(f"Error while reading GPU memory: {e}")
+        free_mem_gb = 6.0  # default value
+        high_vram = False
+
+    print(f'GPU memory: {free_mem_gb:.2f} GB, High-VRAM Mode: {high_vram}')
 
 # Global model references live in the models dict
 models = {}
+cpu_fallback_mode = not GPU_AVAILABLE  # use CPU fallback mode when no GPU is available
 
 # Model loading function
 def load_models():
-    global models
-
-    print("Starting to load the models...")
-
-    # Load the models
-    text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
-    text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
-    tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
-    tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
-    vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()
-
-    feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
-    image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()
-
-    transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=torch.bfloat16).cpu()
-
-    vae.eval()
-    text_encoder.eval()
-    text_encoder_2.eval()
-    image_encoder.eval()
-    transformer.eval()
-
-    if not high_vram:
-        vae.enable_slicing()
-        vae.enable_tiling()
-
-    transformer.high_quality_fp32_output_for_inference = True
-    print('transformer.high_quality_fp32_output_for_inference = True')
-
-    transformer.to(dtype=torch.bfloat16)
-    vae.to(dtype=torch.float16)
-    image_encoder.to(dtype=torch.float16)
-    text_encoder.to(dtype=torch.float16)
-    text_encoder_2.to(dtype=torch.float16)
-
-    vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-    text_encoder_2.requires_grad_(False)
-    image_encoder.requires_grad_(False)
-    transformer.requires_grad_(False)
-
-    if torch.cuda.is_available():
-        if not high_vram:
-            # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
-            DynamicSwapInstaller.install_model(transformer, device=gpu)
-            DynamicSwapInstaller.install_model(text_encoder, device=gpu)
-        else:
-            text_encoder.to(gpu)
-            text_encoder_2.to(gpu)
-            image_encoder.to(gpu)
-            vae.to(gpu)
-            transformer.to(gpu)
-
-    # Save to the global variable
-    models = {
-        'text_encoder': text_encoder,
-        'text_encoder_2': text_encoder_2,
-        'tokenizer': tokenizer,
-        'tokenizer_2': tokenizer_2,
-        'vae': vae,
-        'feature_extractor': feature_extractor,
-        'image_encoder': image_encoder,
-        'transformer': transformer
-    }
-
-    return models
+    global models, cpu_fallback_mode, GPU_INITIALIZED
+
+    if GPU_INITIALIZED:
+        print("Models already loaded; skipping reload")
+        return models
+
+    print("Starting to load the models...")
+
+    try:
+        # Pick the device based on GPU availability
+        device = 'cuda' if GPU_AVAILABLE and not cpu_fallback_mode else 'cpu'
+        model_device = 'cpu'  # load to CPU first
+
+        # Lower the precision to save memory
+        dtype = torch.float16 if GPU_AVAILABLE else torch.float32
+        transformer_dtype = torch.bfloat16 if GPU_AVAILABLE else torch.float32
+
+        print(f"Device: {device}, model dtype: {dtype}, transformer dtype: {transformer_dtype}")
+
+        # Load the models
+        try:
+            text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=dtype).to(model_device)
+            text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=dtype).to(model_device)
+            tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
+            tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
+            vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=dtype).to(model_device)
+
+            feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
+            image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=dtype).to(model_device)
+
+            transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=transformer_dtype).to(model_device)
+
+            print("All models loaded successfully")
+        except Exception as e:
+            print(f"Error while loading the models: {e}")
+            print("Retrying at lower precision...")
+
+            # Retry at lower precision
+            dtype = torch.float32
+            transformer_dtype = torch.float32
+            cpu_fallback_mode = True
+
+            text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=dtype).to('cpu')
+            text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=dtype).to('cpu')
+            tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
+            tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
+            vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=dtype).to('cpu')
+
+            feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
+            image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=dtype).to('cpu')
+
+            transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=transformer_dtype).to('cpu')
+
+            print("All models loaded successfully in CPU mode")
+
+        vae.eval()
+        text_encoder.eval()
+        text_encoder_2.eval()
+        image_encoder.eval()
+        transformer.eval()
+
+        if not high_vram or cpu_fallback_mode:
+            vae.enable_slicing()
+            vae.enable_tiling()
+
+        transformer.high_quality_fp32_output_for_inference = True
+        print('transformer.high_quality_fp32_output_for_inference = True')
+
+        # Set the model precision
+        if not cpu_fallback_mode:
+            transformer.to(dtype=transformer_dtype)
+            vae.to(dtype=dtype)
+            image_encoder.to(dtype=dtype)
+            text_encoder.to(dtype=dtype)
+            text_encoder_2.to(dtype=dtype)
+
+        vae.requires_grad_(False)
+        text_encoder.requires_grad_(False)
+        text_encoder_2.requires_grad_(False)
+        image_encoder.requires_grad_(False)
+        transformer.requires_grad_(False)
+
+        if torch.cuda.is_available() and not cpu_fallback_mode:
+            try:
+                if not high_vram:
+                    # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
+                    DynamicSwapInstaller.install_model(transformer, device=device)
+                    DynamicSwapInstaller.install_model(text_encoder, device=device)
+                else:
+                    text_encoder.to(device)
+                    text_encoder_2.to(device)
+                    image_encoder.to(device)
+                    vae.to(device)
+                    transformer.to(device)
+                print(f"Models moved to {device} successfully")
+            except Exception as e:
+                print(f"Error while moving the models to {device}: {e}")
+                print("Falling back to CPU mode")
+                cpu_fallback_mode = True
+
+        # Save to the global variable
+        models = {
+            'text_encoder': text_encoder,
+            'text_encoder_2': text_encoder_2,
+            'tokenizer': tokenizer,
+            'tokenizer_2': tokenizer_2,
+            'vae': vae,
+            'feature_extractor': feature_extractor,
+            'image_encoder': image_encoder,
+            'transformer': transformer
+        }
+
+        GPU_INITIALIZED = True
+        print(f"Model loading finished; run mode: {'CPU' if cpu_fallback_mode else 'GPU'}")
+        return models
+    except Exception as e:
+        print(f"Error during model loading: {e}")
+        traceback.print_exc()
+
+        # Record more detailed error information
+        error_info = {
+            "error": str(e),
+            "traceback": traceback.format_exc(),
+            "cuda_available": torch.cuda.is_available(),
+            "device": "cpu" if cpu_fallback_mode else "cuda",
+        }
+
+        # Save the error info to a file to ease debugging
+        try:
+            with open(os.path.join(outputs_folder, "error_log.txt"), "w") as f:
+                f.write(str(error_info))
+        except:
+            pass
+
+        # Return an empty dict so the app can keep trying to run
+        cpu_fallback_mode = True
+        return {}
 
 
 # Use the Hugging Face Spaces GPU decorator
-if IN_HF_SPACE and 'spaces' in globals():
-    @spaces.GPU
-    def initialize_models():
-        """Initialize the models inside the @spaces.GPU decorator"""
-        return load_models()
+if IN_HF_SPACE and 'spaces' in globals() and GPU_AVAILABLE:
+    try:
+        @spaces.GPU
+        def initialize_models():
+            """Initialize the models inside the @spaces.GPU decorator"""
+            global GPU_INITIALIZED
+            try:
+                result = load_models()
+                GPU_INITIALIZED = True
+                return result
+            except Exception as e:
+                print(f"Error while initializing the models with spaces.GPU: {e}")
+                traceback.print_exc()
+                global cpu_fallback_mode
+                cpu_fallback_mode = True
+                # Try again without the decorator
+                return load_models()
+    except Exception as e:
+        print(f"Error while creating the spaces.GPU decorator: {e}")
+        # If the decorator fails, fall back to the undecorated version
+        def initialize_models():
+            return load_models()
 
 
 # The functions below fetch the models lazily
 def get_models():
     """Get the models, loading them first if necessary"""
-    global models
+    global models, GPU_INITIALIZED
 
     # Model-loading lock to prevent concurrent loads
     model_loading_key = "__model_loading__"
@@ -248,20 +373,37 @@ def get_models():
         print("Models are loading; waiting...")
         # Wait for loading to finish
        import time
+        start_wait = time.time()
        while not models and model_loading_key in globals():
            time.sleep(0.5)
-        return models
+            # Treat loading as failed after 60 seconds
+            if time.time() - start_wait > 60:
+                print("Timed out waiting for the models to load")
+                break
+
+        if models:
+            return models
 
    try:
        # Set the loading flag
        globals()[model_loading_key] = True
 
-        if IN_HF_SPACE and 'spaces' in globals():
-            print("Loading models via the @spaces.GPU decorator")
-            models = initialize_models()
+        if IN_HF_SPACE and 'spaces' in globals() and GPU_AVAILABLE and not cpu_fallback_mode:
+            try:
+                print("Loading models via the @spaces.GPU decorator")
+                models = initialize_models()
+            except Exception as e:
+                print(f"Loading models via the GPU decorator failed: {e}")
+                print("Trying to load the models directly")
+                models = load_models()
        else:
            print("Loading models directly")
-            load_models()
+            models = load_models()
+    except Exception as e:
+        print(f"Unexpected error while loading models: {e}")
+        traceback.print_exc()
+        # Make sure models is at least an empty dict
+        models = {}
    finally:
        # Remove the loading flag whether or not loading succeeded
        if model_loading_key in globals():
@@ -275,16 +417,46 @@ stream = AsyncStream()
 
 @torch.no_grad()
 def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache):
+    global last_update_time
+    last_update_time = time.time()
+
     # Fetch the models
-    models = get_models()
-    text_encoder = models['text_encoder']
-    text_encoder_2 = models['text_encoder_2']
-    tokenizer = models['tokenizer']
-    tokenizer_2 = models['tokenizer_2']
-    vae = models['vae']
-    feature_extractor = models['feature_extractor']
-    image_encoder = models['image_encoder']
-    transformer = models['transformer']
+    try:
+        models = get_models()
+        if not models:
+            error_msg = "Model loading failed; check the logs for details"
+            print(error_msg)
+            stream.output_queue.push(('error', error_msg))
+            stream.output_queue.push(('end', None))
+            return
+
+        text_encoder = models['text_encoder']
+        text_encoder_2 = models['text_encoder_2']
+        tokenizer = models['tokenizer']
+        tokenizer_2 = models['tokenizer_2']
+        vae = models['vae']
+        feature_extractor = models['feature_extractor']
+        image_encoder = models['image_encoder']
+        transformer = models['transformer']
+    except Exception as e:
+        error_msg = f"Error while fetching the models: {e}"
+        print(error_msg)
+        traceback.print_exc()
+        stream.output_queue.push(('error', error_msg))
+        stream.output_queue.push(('end', None))
+        return
+
+    # Choose the device
+    device = 'cuda' if GPU_AVAILABLE and not cpu_fallback_mode else 'cpu'
+    print(f"Using device {device} for inference")
+
+    # Adjust parameters for CPU mode
+    if cpu_fallback_mode:
+        print("Using trimmed-down parameters in CPU mode")
+        # Shrink the workload to speed up CPU processing
+        latent_window_size = min(latent_window_size, 5)
+        steps = min(steps, 15)  # fewer sampling steps
+        total_second_length = min(total_second_length, 2.0)  # cap the video length
 
     total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
     total_latent_sections = int(max(round(total_latent_sections), 1))
@@ -299,79 +471,136 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
 
     try:
         # Clean GPU
-        if not high_vram:
-            unload_complete_models(
-                text_encoder, text_encoder_2, image_encoder, vae, transformer
-            )
+        if not high_vram and not cpu_fallback_mode:
+            try:
+                unload_complete_models(
+                    text_encoder, text_encoder_2, image_encoder, vae, transformer
+                )
+            except Exception as e:
+                print(f"Error while unloading models: {e}")
+                # keep going; don't abort the pipeline
 
         # Text encoding
-
+        last_update_time = time.time()
         stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
 
-        if not high_vram:
-            fake_diffusers_current_device(text_encoder, gpu)  # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
-            load_model_as_complete(text_encoder_2, target_device=gpu)
-
-        llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
-
-        if cfg == 1:
-            llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
-        else:
-            llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
-
-        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
-        llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
+        try:
+            if not high_vram and not cpu_fallback_mode:
+                fake_diffusers_current_device(text_encoder, device)
+                load_model_as_complete(text_encoder_2, target_device=device)
+
+            llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
+
+            if cfg == 1:
+                llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
+            else:
+                llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
+
+            llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
+            llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
+        except Exception as e:
+            error_msg = f"Error during text encoding: {e}"
+            print(error_msg)
+            traceback.print_exc()
+            stream.output_queue.push(('error', error_msg))
+            stream.output_queue.push(('end', None))
+            return
 
         # Processing input image
-
+        last_update_time = time.time()
        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
 
-        H, W, C = input_image.shape
-        height, width = find_nearest_bucket(H, W, resolution=640)
-        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
-
-        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
-
-        input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
-        input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
+        try:
+            H, W, C = input_image.shape
+            height, width = find_nearest_bucket(H, W, resolution=640)
+
+            # In CPU mode, shrink the processing resolution
+            if cpu_fallback_mode:
+                height = min(height, 320)
+                width = min(width, 320)
+
+            input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
+
+            Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
+
+            input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
+            input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
+        except Exception as e:
+            error_msg = f"Error during image processing: {e}"
+            print(error_msg)
+            traceback.print_exc()
+            stream.output_queue.push(('error', error_msg))
+            stream.output_queue.push(('end', None))
+            return
 
         # VAE encoding
-
+        last_update_time = time.time()
        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
 
-        if not high_vram:
-            load_model_as_complete(vae, target_device=gpu)
-
-        start_latent = vae_encode(input_image_pt, vae)
+        try:
+            if not high_vram and not cpu_fallback_mode:
+                load_model_as_complete(vae, target_device=device)
+
+            start_latent = vae_encode(input_image_pt, vae)
+        except Exception as e:
+            error_msg = f"Error during VAE encoding: {e}"
+            print(error_msg)
+            traceback.print_exc()
+            stream.output_queue.push(('error', error_msg))
+            stream.output_queue.push(('end', None))
+            return
 
         # CLIP Vision
-
+        last_update_time = time.time()
        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
 
-        if not high_vram:
-            load_model_as_complete(image_encoder, target_device=gpu)
-
-        image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
-        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
+        try:
+            if not high_vram and not cpu_fallback_mode:
+                load_model_as_complete(image_encoder, target_device=device)
+
+            image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
+            image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
+        except Exception as e:
+            error_msg = f"Error during CLIP Vision encoding: {e}"
+            print(error_msg)
+            traceback.print_exc()
+            stream.output_queue.push(('error', error_msg))
+            stream.output_queue.push(('end', None))
+            return
 
         # Dtype
-
-        llama_vec = llama_vec.to(transformer.dtype)
-        llama_vec_n = llama_vec_n.to(transformer.dtype)
-        clip_l_pooler = clip_l_pooler.to(transformer.dtype)
-        clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
-        image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
+        try:
+            llama_vec = llama_vec.to(transformer.dtype)
+            llama_vec_n = llama_vec_n.to(transformer.dtype)
+            clip_l_pooler = clip_l_pooler.to(transformer.dtype)
+            clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
+            image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
+        except Exception as e:
+            error_msg = f"Error during dtype conversion: {e}"
+            print(error_msg)
+            traceback.print_exc()
+            stream.output_queue.push(('error', error_msg))
+            stream.output_queue.push(('end', None))
+            return
 
         # Sampling
-
+        last_update_time = time.time()
        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
 
        rnd = torch.Generator("cpu").manual_seed(seed)
        num_frames = latent_window_size * 4 - 3
 
-        history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32).cpu()
-        history_pixels = None
-        total_generated_latent_frames = 0
+        try:
+            history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32).cpu()
+            history_pixels = None
+            total_generated_latent_frames = 0
+        except Exception as e:
+            error_msg = f"Error while initializing the history state: {e}"
+            print(error_msg)
+            traceback.print_exc()
+            stream.output_queue.push(('error', error_msg))
+            stream.output_queue.push(('end', None))
+            return
 
        latent_paddings = reversed(range(total_latent_sections))
 
@@ -383,6 +612,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
 
        for latent_padding in latent_paddings:
+            last_update_time = time.time()
            is_last_section = latent_padding == 0
            latent_padding_size = latent_padding * latent_window_size
 
@@ -401,42 +631,70 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
 
            print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}')
 
-            indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
-            clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
-            clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
-
-            clean_latents_pre = start_latent.to(history_latents)
-            clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
-            clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
+            try:
+                indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
+                clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
+                clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
+
+                clean_latents_pre = start_latent.to(history_latents)
+                clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
+                clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
+            except Exception as e:
+                error_msg = f"Error while preparing the sampling data: {e}"
+                print(error_msg)
+                traceback.print_exc()
+                # Try to continue with the next iteration instead of terminating outright
+                if last_output_filename:
+                    stream.output_queue.push(('file', last_output_filename))
+                continue
 
-            if not high_vram:
-                unload_complete_models()
-                move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
+            if not high_vram and not cpu_fallback_mode:
+                try:
+                    unload_complete_models()
+                    move_model_to_device_with_memory_preservation(transformer, target_device=device, preserved_memory_gb=gpu_memory_preservation)
+                except Exception as e:
+                    print(f"Error while moving the transformer to the GPU: {e}")
+                    # keep going; this may hurt performance but need not be fatal
 
-            if use_teacache:
-                transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
+            if use_teacache and not cpu_fallback_mode:
+                try:
+                    transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
+                except Exception as e:
+                    print(f"Error while initializing teacache: {e}")
+                    # disable teacache and continue
+                    transformer.initialize_teacache(enable_teacache=False)
            else:
                transformer.initialize_teacache(enable_teacache=False)
 
            def callback(d):
-                preview = d['denoised']
-                preview = vae_decode_fake(preview)
-
-                preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
-                preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
-
-                if stream.input_queue.top() == 'end':
-                    stream.output_queue.push(('end', None))
-                    raise KeyboardInterrupt('User ends the task.')
-
-                current_step = d['i'] + 1
-                percentage = int(100.0 * current_step / steps)
-                hint = f'Sampling {current_step}/{steps}'
-                desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). The video is being extended now ...'
-                stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
+                global last_update_time
+                last_update_time = time.time()
+
+                try:
+                    preview = d['denoised']
+                    preview = vae_decode_fake(preview)
+
+                    preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
+                    preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
+
+                    if stream.input_queue.top() == 'end':
+                        stream.output_queue.push(('end', None))
+                        raise KeyboardInterrupt('User ends the task.')
+
+                    current_step = d['i'] + 1
+                    percentage = int(100.0 * current_step / steps)
+                    hint = f'Sampling {current_step}/{steps}'
+                    desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). The video is being extended now ...'
+                    stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
+                except Exception as e:
+                    print(f"Error in the callback: {e}")
+                    # don't interrupt sampling
                return
 
            try:
+                sampling_start_time = time.time()
+                print(f"Starting sampling; device: {device}, dtype: {transformer.dtype}, TeaCache: {use_teacache and not cpu_fallback_mode}")
+
                generated_latents = sample_hunyuan(
                    transformer=transformer,
                    sampler='unipc',
@@ -455,8 +713,8 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
                    negative_prompt_embeds=llama_vec_n,
                    negative_prompt_embeds_mask=llama_attention_mask_n,
                    negative_prompt_poolers=clip_l_pooler_n,
-                    device=gpu,
-                    dtype=torch.bfloat16,
+                    device=device,
+                    dtype=transformer.dtype,
                    image_embeddings=image_encoder_last_hidden_state,
                    latent_indices=latent_indices,
                    clean_latents=clean_latents,
@@ -467,6 +725,8 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
                    clean_latent_4x_indices=clean_latent_4x_indices,
                    callback=callback,
                )
+
+                print(f"Sampling finished in {time.time() - sampling_start_time:.2f} s")
            except Exception as e:
                print(f"Error during sampling: {e}")
                traceback.print_exc()
@@ -474,23 +734,57 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
                # If a video was already generated, return the most recent one
                if last_output_filename:
                    stream.output_queue.push(('file', last_output_filename))
+
+                    # Build the error message
+                    error_msg = f"Error during sampling, but a partial video was returned: {e}"
+                    stream.output_queue.push(('error', error_msg))
+                else:
+                    # No video was generated; return an error message
+                    error_msg = f"Error during sampling; no video could be generated: {e}"
+                    stream.output_queue.push(('error', error_msg))
 
                stream.output_queue.push(('end', None))
                return
 
-            if is_last_section:
-                generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
-
-            total_generated_latent_frames += int(generated_latents.shape[2])
-            history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
+            try:
+                if is_last_section:
+                    generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
+
+                total_generated_latent_frames += int(generated_latents.shape[2])
+                history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
+            except Exception as e:
+                error_msg = f"Error while processing the generated latents: {e}"
+                print(error_msg)
+                traceback.print_exc()
+
+                if last_output_filename:
+                    stream.output_queue.push(('file', last_output_filename))
+                stream.output_queue.push(('error', error_msg))
+                stream.output_queue.push(('end', None))
+                return
 
-            if not high_vram:
-                offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
-                load_model_as_complete(vae, target_device=gpu)
+            if not high_vram and not cpu_fallback_mode:
+                try:
+                    offload_model_from_device_for_memory_preservation(transformer, target_device=device, preserved_memory_gb=8)
+                    load_model_as_complete(vae, target_device=device)
+                except Exception as e:
+                    print(f"Error while managing model memory: {e}")
+                    # keep going
 
-            real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
+            try:
+                real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
+            except Exception as e:
+                error_msg = f"Error while processing the history latents: {e}"
+                print(error_msg)
+
+                if last_output_filename:
+                    stream.output_queue.push(('file', last_output_filename))
+                continue
 
            try:
+                vae_start_time = time.time()
+                print(f"Starting VAE decode; latent shape: {real_history_latents.shape}")
+
                if history_pixels is None:
                    history_pixels = vae_decode(real_history_latents, vae).cpu()
                else:
@@ -500,12 +794,19 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
                    current_pixels = vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
                    history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
 
-                if not high_vram:
-                    unload_complete_models()
+                print(f"VAE decode finished in {time.time() - vae_start_time:.2f} s")
+
+                if not high_vram and not cpu_fallback_mode:
+                    try:
+                        unload_complete_models()
+                    except Exception as e:
+                        print(f"Error while unloading models: {e}")
 
                output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
 
+                save_start_time = time.time()
                save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
+                print(f"Video saved in {time.time() - save_start_time:.2f} s")
 
                print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
@@ -519,6 +820,10 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
                if last_output_filename:
                    stream.output_queue.push(('file', last_output_filename))
 
+                # Record the error
+                error_msg = f"Error during video decode or save: {e}"
+                stream.output_queue.push(('error', error_msg))
+
                # Try to continue with the next iteration
                continue
 
@@ -528,7 +833,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
        print(f"Error during processing: {e}")
        traceback.print_exc()
 
-        if not high_vram:
+        if not high_vram and not cpu_fallback_mode:
            try:
                unload_complete_models(
                    text_encoder, text_encoder_2, image_encoder, vae, transformer
@@ -539,6 +844,10 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
        # If a video was already generated, return the most recent one
        if last_output_filename:
            stream.output_queue.push(('file', last_output_filename))
+
+            # Return the error message as well
+            error_msg = f"Error during processing: {e}"
+            stream.output_queue.push(('error', error_msg))
 
        # Always push the end signal
        stream.output_queue.push(('end', None))
@@ -563,6 +872,7 @@ if IN_HF_SPACE and 'spaces' in globals():
 
        output_filename = None
        prev_output_filename = None
+        error_message = None
 
        # Keep polling the worker's output
        while True:
@@ -577,13 +887,23 @@ if IN_HF_SPACE and 'spaces' in globals():
                if flag == 'progress':
                    preview, desc, html = data
                    yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
+
+                if flag == 'error':
+                    error_message = data
+                    print(f"Received error message: {error_message}")
+                    # Don't display it yet; wait for the end signal
 
                if flag == 'end':
                    # If there is a final video file, make sure it is returned
                    if output_filename is None and prev_output_filename is not None:
                        output_filename = prev_output_filename
-
-                    yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
+
+                    # If an error message arrived, build a friendly error display
+                    if error_message:
+                        error_html = create_error_html(error_message)
+                        yield output_filename, gr.update(visible=False), gr.update(), error_html, gr.update(interactive=True), gr.update(interactive=False)
+                    else:
+                        yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
                    break
            except Exception as e:
                print(f"Error while handling output: {e}")
@@ -594,52 +914,10 @@ if IN_HF_SPACE and 'spaces' in globals():
 
                # If a partial video was generated, return it
                if prev_output_filename:
-                    # Build the bilingual partial-video message
-                    partial_video_msg = f"""
-                    <div id="partial-video-container">
-                        <div class="msg-en" data-lang="en">Processing error, but partial video has been generated</div>
-                        <div class="msg-zh" data-lang="zh">处理过程中出现错误,但已生成部分视频</div>
-                    </div>
-                    <script>
-                        // Show the message for the current language
-                        (function() {{
-                            const container = document.getElementById('partial-video-container');
-                            if (container) {{
-                                const currentLang = window.currentLang || 'en'; // default to English
-                                const msgs = container.querySelectorAll('[data-lang]');
-                                msgs.forEach(msg => {{
-                                    msg.style.display = msg.getAttribute('data-lang') === currentLang ? 'block' : 'none';
-                                }});
-                            }}
-                        }})();
-                    </script>
-                    """
-                    yield prev_output_filename, gr.update(visible=False), gr.update(), partial_video_msg, gr.update(interactive=True), gr.update(interactive=False)
+                    error_html = create_error_html("Processing timed out, but a partial video was generated", is_timeout=True)
+                    yield prev_output_filename, gr.update(visible=False), gr.update(), error_html, gr.update(interactive=True), gr.update(interactive=False)
                else:
-                    # Build the bilingual error message
-                    error_msg = str(e)
-                    en_msg = f"Processing error: {error_msg}"
-                    zh_msg = f"处理过程中出现错误: {error_msg}"
-
-                    error_html = f"""
-                    <div id="error-msg-container">
-                        <div class="error-msg-en" data-lang="en">{en_msg}</div>
-                        <div class="error-msg-zh" data-lang="zh">{zh_msg}</div>
-                    </div>
-                    <script>
-                        // Show the error message for the current language
-                        (function() {{
-                            const errorContainer = document.getElementById('error-msg-container');
-                            if (errorContainer) {{
-                                const currentLang = window.currentLang || 'en'; // default to English
-                                const errMsgs = errorContainer.querySelectorAll('[data-lang]');
-                                errMsgs.forEach(msg => {{
-                                    msg.style.display = msg.getAttribute('data-lang') === currentLang ? 'block' : 'none';
-                                }});
-                            }}
-                        }})();
-                    </script>
-                    """
+                    error_html = create_error_html(f"Processing timed out: {e}", is_timeout=True)
                    yield None, gr.update(visible=False), gr.update(), error_html, gr.update(interactive=True), gr.update(interactive=False)
                break
 
@@ -647,47 +925,9 @@ if IN_HF_SPACE and 'spaces' in globals():
            print(f"Error while starting processing: {e}")
            traceback.print_exc()
            error_msg = str(e)
-            user_friendly_msg = f'Processing error: {error_msg}'
-
-            # Provide friendlier bilingual (EN/ZH) error messages
-            en_msg = ""
-            zh_msg = ""
-
-            if "模型下载超时" in error_msg or "网络连接不稳定" in error_msg or "ReadTimeoutError" in error_msg or "ConnectionError" in error_msg:
-                en_msg = "Network connection is unstable, model download timed out. Please try again later."
-                zh_msg = "网络连接不稳定,模型下载超时。请稍后再试。"
-            elif "GPU内存不足" in error_msg or "CUDA out of memory" in error_msg or "OutOfMemoryError" in error_msg:
-                en_msg = "GPU memory insufficient, please try increasing GPU memory preservation value or reduce video length."
-                zh_msg = "GPU内存不足,请尝试增加GPU推理保留内存值或降低视频长度。"
-            elif "无法加载模型" in error_msg:
-                en_msg = "Failed to load model, possibly due to network issues or high server load. Please try again later."
-                zh_msg = "模型加载失败,可能是网络问题或服务器负载过高。请稍后再试。"
-            else:
-                en_msg = f"Processing error: {error_msg}"
-                zh_msg = f"处理过程出错: {error_msg}"
-
-            # Build the bilingual error HTML
-            bilingual_error = f"""
-            <div id="error-container">
-                <div class="error-msg-en" data-lang="en">{en_msg}</div>
-                <div class="error-msg-zh" data-lang="zh">{zh_msg}</div>
-            </div>
-            <script>
-                // Show the error message for the current language
-                (function() {{
-                    const errorContainer = document.getElementById('error-container');
-                    if (errorContainer) {{
-                        const currentLang = window.currentLang || 'en'; // default to English
-                        const errMsgs = errorContainer.querySelectorAll('[data-lang]');
-                        errMsgs.forEach(msg => {{
-                            msg.style.display = msg.getAttribute('data-lang') === currentLang ? 'block' : 'none';
-                        }});
-                    }}
-                }})();
-            </script>
-            """
-
-            yield None, gr.update(visible=False), gr.update(), bilingual_error, gr.update(interactive=True), gr.update(interactive=False)
+
+            error_html = create_error_html(error_msg)
+            yield None, gr.update(visible=False), gr.update(), error_html, gr.update(interactive=True), gr.update(interactive=False)
 
    process = process_with_gpu
 else:
@@ -706,6 +946,7 @@ else:
 
        output_filename = None
        prev_output_filename = None
+        error_message = None
 
        # Keep polling the worker's output
        while True:
@@ -720,13 +961,23 @@ else:
                if flag == 'progress':
                    preview, desc, html = data
                    yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
+
+                if flag == 'error':
+                    error_message = data
+                    print(f"Received error message: {error_message}")
+                    # Don't display it yet; wait for the end signal
 
                if flag == 'end':
                    # If there is a final video file, make sure it is returned
                    if output_filename is None and prev_output_filename is not None:
                        output_filename = prev_output_filename
-
-                    yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
+
+                    # If an error message arrived, build a friendly error display
+                    if error_message:
+                        error_html = create_error_html(error_message)
+                        yield output_filename, gr.update(visible=False), gr.update(), error_html, gr.update(interactive=True), gr.update(interactive=False)
+                    else:
+                        yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
                    break
            except Exception as e:
                print(f"Error while handling output: {e}")
@@ -737,74 +988,20 @@ else:
 
                # If a partial video was generated, return it
                if prev_output_filename:
-                    # Bilingual support for the interruption message
-                    interrupt_msg = f"""
-                    <div id="interrupt-container">
-                        <div class="msg-en" data-lang="en">Processing was interrupted, but partial video has been generated</div>
-                        <div class="msg-zh" data-lang="zh">处理过程中断,但已生成部分视频</div>
-                    </div>
-                    <script>
-                        // Show the message for the current language
-                        (function() {{
-                            const container = document.getElementById('interrupt-container');
-                            if (container) {{
-                                const currentLang = window.currentLang || 'en'; // default to English
-                                const msgs = container.querySelectorAll('[data-lang]');
-                                msgs.forEach(msg => {{
-                                    msg.style.display = msg.getAttribute('data-lang') === currentLang ? 'block' : 'none';
-                                }});
-                            }}
-                        }})();
-                    </script>
-                    """
-                    yield prev_output_filename, gr.update(visible=False), gr.update(), interrupt_msg, gr.update(interactive=True), gr.update(interactive=False)
-                    break
+                    error_html = create_error_html("Processing timed out, but a partial video was generated", is_timeout=True)
+                    yield prev_output_filename, gr.update(visible=False), gr.update(), error_html, gr.update(interactive=True), gr.update(interactive=False)
+                else:
+                    error_html = create_error_html(f"Processing timed out: {e}", is_timeout=True)
+                    yield None, gr.update(visible=False), gr.update(), error_html, gr.update(interactive=True), gr.update(interactive=False)
+                break
 
        except Exception as e:
            print(f"Error while starting processing: {e}")
            traceback.print_exc()
            error_msg = str(e)
-            user_friendly_msg = f'Processing error: {error_msg}'
-
-            # Provide friendlier bilingual (EN/ZH) error messages
-            en_msg = ""
-            zh_msg = ""
-
-            if "模型下载超时" in error_msg or "网络连接不稳定" in error_msg or "ReadTimeoutError" in error_msg or "ConnectionError" in error_msg:
-                en_msg = "Network connection is unstable, model download timed out. Please try again later."
-                zh_msg = "网络连接不稳定,模型下载超时。请稍后再试。"
-            elif "GPU内存不足" in error_msg or "CUDA out of memory" in error_msg or "OutOfMemoryError" in error_msg:
-                en_msg = "GPU memory insufficient, please try increasing GPU memory preservation value or reduce video length."
-                zh_msg = "GPU内存不足,请尝试增加GPU推理保留内存值或降低视频长度。"
-            elif "无法加载模型" in error_msg:
-                en_msg = "Failed to load model, possibly due to network issues or high server load. Please try again later."
-                zh_msg = "模型加载失败,可能是网络问题或服务器负载过高。请稍后再试。"
-            else:
-                en_msg = f"Processing error: {error_msg}"
-                zh_msg = f"处理过程出错: {error_msg}"
-
-            # Build the bilingual error HTML
-            bilingual_error = f"""
-            <div id="error-container">
-                <div class="error-msg-en" data-lang="en">{en_msg}</div>
-                <div class="error-msg-zh" data-lang="zh">{zh_msg}</div>
-            </div>
-            <script>
-                // Show the error message for the current language
-                (function() {{
-                    const errorContainer = document.getElementById('error-container');
-                    if (errorContainer) {{
-                        const currentLang = window.currentLang || 'en'; // default to English
-                        const errMsgs = errorContainer.querySelectorAll('[data-lang]');
-                        errMsgs.forEach(msg => {{
-                            msg.style.display = msg.getAttribute('data-lang') === currentLang ? 'block' : 'none';
-                        }});
-                    }}
-                }})();
-            </script>
-            """
-
-            yield None, gr.update(visible=False), gr.update(), bilingual_error, gr.update(interactive=True), gr.update(interactive=False)
+
+            error_html = create_error_html(error_msg)
+            yield None, gr.update(visible=False), gr.update(), error_html, gr.update(interactive=True), gr.update(interactive=False)
 
 
 def end_process():
@@ -1268,4 +1465,58 @@ with block:
    end_button.click(fn=end_process)
 
 
-block.launch()
+# Build friendly error-display HTML.
+# Note: this must be defined before block.launch() is called, because launch()
+# blocks, and the Gradio callbacks above call create_error_html at runtime.
+def create_error_html(error_msg, is_timeout=False):
+    """Create a bilingual (EN/ZH) error-message HTML block"""
+    # Provide friendlier bilingual error messages
+    en_msg = ""
+    zh_msg = ""
+
+    if is_timeout:
+        en_msg = "Processing timed out, but partial video may have been generated" if "partial video" in error_msg else f"Processing timed out: {error_msg}"
+        zh_msg = "处理超时,但已生成部分视频" if "partial video" in error_msg else f"处理超时: {error_msg}"
+    elif "Model loading failed" in error_msg:
+        en_msg = "Failed to load models. The Space may be experiencing high traffic or GPU issues."
+        zh_msg = "模型加载失败,可能是Space流量过高或GPU资源不足。"
+    elif "GPU" in error_msg or "CUDA" in error_msg or "memory" in error_msg:
+        en_msg = "GPU memory insufficient or GPU error. Try increasing GPU memory preservation value or reduce video length."
+        zh_msg = "GPU内存不足或GPU错误,请尝试增加GPU推理保留内存值或降低视频长度。"
+    elif "Error during sampling" in error_msg:
+        if "partial" in error_msg:
+            en_msg = "Error during sampling process, but partial video has been generated."
+            zh_msg = "采样过程中出错,但已生成部分视频。"
+        else:
+            en_msg = "Error during sampling process. Unable to generate video."
+            zh_msg = "采样过程中出错,无法生成视频。"
+    elif "ReadTimeoutError" in error_msg or "ConnectionError" in error_msg:
+        en_msg = "Network connection is unstable, model download timed out. Please try again later."
+        zh_msg = "网络连接不稳定,模型下载超时。请稍后再试。"
+    elif "VAE" in error_msg or "decode" in error_msg:
+        en_msg = "Error during video decoding or saving process. Try again with a different seed."
+        zh_msg = "视频解码或保存过程中出错,请尝试使用不同的随机种子。"
+    else:
+        en_msg = f"Processing error: {error_msg}"
+        zh_msg = f"处理过程出错: {error_msg}"
+
+    # Build the bilingual error-message HTML
+    return f"""
+    <div id="error-container" class="error-message">
+        <div class="error-msg-en" data-lang="en">{en_msg}</div>
+        <div class="error-msg-zh" data-lang="zh">{zh_msg}</div>
+    </div>
+    <script>
+        // Show the error message for the current language
+        (function() {{
+            const errorContainer = document.getElementById('error-container');
+            if (errorContainer) {{
+                const currentLang = window.currentLang || 'en'; // default to English
+                const errMsgs = errorContainer.querySelectorAll('[data-lang]');
+                errMsgs.forEach(msg => {{
+                    msg.style.display = msg.getAttribute('data-lang') === currentLang ? 'block' : 'none';
+                }});
+            }}
+        }})();
+    </script>
+    """
+
+block.launch()